use anyhow::{Context as _, Result, anyhow, bail};
use arrayvec::ArrayVec;
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
use cloud_llm_client::{
    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
use collections::HashMap;
use edit_prediction_context::{
    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
    SyntaxIndex, SyntaxIndexState,
};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::AsyncReadExt as _;
use futures::channel::{mpsc, oneshot};
use gpui::http_client::{AsyncBody, Method};
use gpui::{
    App, AsyncApp, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task,
    WeakEntity, http_client, prelude::*,
};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use lsp::DiagnosticSeverity;
use open_ai::FunctionDefinition;
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};

use std::fmt::Write;
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use std::{env, mem};
use thiserror::Error;
use util::rel_path::RelPathBuf;
use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};

pub mod assemble_excerpts;
mod prediction;
mod provider;
pub mod retrieval_search;
mod sweep_ai;
pub mod udiff;
mod xml_edits;

use crate::assemble_excerpts::assemble_excerpts;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;

/// Maximum number of events to track when predicting with Sweep.
const EVENT_COUNT_MAX_SWEEP: usize = 6;
/// Maximum number of events to track when predicting with Zed Cloud.
const EVENT_COUNT_MAX_ZETA: usize = 16;
/// Maximum row distance between consecutive edits for them to be coalesced
/// into a single event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;

pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}
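
/// Default excerpt sizing: between 128 and 512 bytes around the cursor, with
/// roughly half of the byte budget spent on text before the cursor.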
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});

pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        false
    }
}

#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

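/// Global entry point for zeta2 edit predictions. Tracks per-project state
/// (edit history, registered buffers, retrieved context, pending predictions)
/// and talks to the configured prediction backend.
///
/// Illustrative use (a sketch based on the methods in this file, not a
/// prescribed API):
///
/// ```ignore
/// let zeta = Zeta::global(&client, &user_store, cx);
/// zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
/// ```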
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    update_required: bool,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_api_token: Option<String>,
    sweep_ai_debug_info: Arc<str>,
}

#[derive(PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    ZedCloud,
    Sweep,
}

#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

impl ContextMode {
    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
        match self {
            ContextMode::Agentic(options) => &options.excerpt,
            ContextMode::Syntax(options) => &options.excerpt,
        }
    }
}

#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

struct ZetaProject {
    syntax_index: Option<Entity<SyntaxIndex>>,
    events: VecDeque<Event>,
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    _subscription: gpui::Subscription,
}

#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
}

impl CurrentEditPrediction {
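    /// Decides whether this freshly requested prediction should displace
    /// `old_prediction`. A new prediction that no longer interpolates against
    /// its buffer is rejected outright; otherwise the old one is kept only when
    /// the new edit merely extends the same single edit, which avoids UI thrash.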
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits.
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}

#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    Buffer(EntityId),
}

impl PredictionRequestedBy {
    pub fn buffer_id(&self) -> Option<EntityId> {
        match self {
            PredictionRequestedBy::DiagnosticsUpdate => None,
            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
        }
    }
}

struct PendingPrediction {
    id: usize,
    _task: Task<()>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        end_edit_anchor: Option<Anchor>,
        timestamp: Instant,
    },
}

impl Event {
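    /// Converts a buffer-change event into the wire format, rendering the
    /// change as a unified diff. Returns `None` when there is nothing to report.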
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                // `old_path` is `None` whenever the path is unchanged, so skip the
                // event only when there was neither a rename nor an edit.
                if old_path.is_none() && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        // TODO: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }

    pub fn project_path(&self, cx: &App) -> Option<project::ProjectPath> {
        match self {
            Event::BufferChange { new_snapshot, .. } => new_snapshot
                .file()
                .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)),
        }
    }
}

impl Zeta {
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
    }

    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
            sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
                .context("No SWEEP_AI_TOKEN environment variable set")
                .log_err(),
            sweep_ai_debug_info: sweep_ai::debug_info(cx),
        }
    }

    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }

    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_api_token.is_some()
    }

    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }

    pub fn history_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl DoubleEndedIterator<Item = &Event> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.events.iter())
            .into_iter()
            .flatten()
    }

    pub fn context_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
        self.projects
            .get(&project.entity_id())
            .and_then(|project| {
                Some(
                    project
                        .context
                        .as_ref()?
                        .iter()
                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
                )
            })
            .into_iter()
            .flatten()
    }

    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        if self.edit_prediction_model == ZetaEditPredictionModel::ZedCloud {
            self.user_store.read(cx).edit_prediction_usage()
        } else {
            None
        }
    }

    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }

    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }

    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                self.refresh_prediction_from_diagnostics(project, cx);
            }
            _ => (),
        }
    }

    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

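    /// Records a buffer change in the project's event history. Consecutive
    /// edits to the same buffer within `CHANGE_GROUPING_LINE_SPAN` rows are
    /// coalesced into the previous event, so typing a word produces one event
    /// rather than one per keystroke; the oldest event is dropped once the
    /// per-model cap is reached.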
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let event_count_max = match self.edit_prediction_model {
            ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
            ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
        };

        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut zeta_project.events;

        if let Some(Event::BufferChange {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = events.back_mut()
        {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        if events.len() >= event_count_max {
            events.pop_front();
        }

        events.push_back(Event::BufferChange {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
            timestamp: Instant::now(),
        });
    }

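    /// Resolves the project's current prediction relative to `buffer`: edits in
    /// this buffer surface as `Local`, while a prediction that lives in another
    /// buffer surfaces as a `Jump` target when appropriate.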
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }

    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        if self.edit_prediction_model != ZetaEditPredictionModel::ZedCloud {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        project_state.pending_predictions.clear();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.current_prediction.take();
            project_state.pending_predictions.clear();
        };
    }

    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
        self.projects
            .get(&project.entity_id())
            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
    }

    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(&project, &buffer, position, cx)
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(()));
            };

            let project = project.clone();
            cx.spawn(async move |cx| {
                if let Some(prediction) = request_task.await? {
                    this.update(cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project.entity_id())
                            .context("Project not found")?;

                        let new_prediction = CurrentEditPrediction {
                            requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
                            prediction,
                        };

                        if project_state
                            .current_prediction
                            .as_ref()
                            .is_none_or(|old_prediction| {
                                new_prediction.should_replace_prediction(old_prediction, cx)
                            })
                        {
                            project_state.current_prediction = Some(new_prediction);
                            cx.notify();
                        }
                        anyhow::Ok(())
                    })??;
                }
                Ok(())
            })
        })
    }

    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions requested from a buffer.
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(()));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(());
                };

                let Some(prediction) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(());
                };

                this.update(cx, |this, cx| {
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.current_prediction.get_or_insert_with(|| {
                            cx.notify();
                            CurrentEditPrediction {
                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
                                prediction,
                            }
                        });
                    }
                })?;

                anyhow::Ok(())
            })
        });
    }

    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;

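    /// Schedules a prediction refresh, throttling consecutive refreshes for the
    /// same entity to at most one per `THROTTLE_TIMEOUT`. At most two refreshes
    /// are tracked per project: the one in flight plus a single queued
    /// follow-up, where a newer request replaces the queued one.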
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(WeakEntity<Self>, &mut AsyncApp) -> Task<Result<()>> + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        // TODO report cancelled requests like in zeta1
        let task = cx.spawn(async move |this, cx| {
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            do_refresh(this.clone(), cx).await.log_err();

            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                if zeta_project.pending_predictions[0].id == pending_prediction_id {
                    zeta_project.pending_predictions.remove(0);
                } else {
                    zeta_project.pending_predictions.clear();
                }

                cx.notify();
            })
            .ok();
        });

        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                _task: task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            zeta_project.pending_predictions.pop();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                _task: task,
            });
        }
    }

    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::ZedCloud => {
                self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
            }
            ZetaEditPredictionModel::Sweep => {
                self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
            }
        }
    }

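    /// Requests a prediction from the Sweep backend. The request body (current
    /// file, recent changes, nearby diagnostics, and chunks of recently viewed
    /// files) is brotli-compressed and POSTed, and the returned completion is
    /// diffed against the snapshot to produce anchored edits. When no edit
    /// comes back and `allow_jump` is set, we retry once at the nearest
    /// diagnostic location.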
    fn request_prediction_with_sweep(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = active_buffer.read(cx).snapshot();
        let debug_info = self.sweep_ai_debug_info.clone();
        let Some(api_token) = self.sweep_api_token.clone() else {
            return Task::ready(Ok(None));
        };
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|file| file.full_path(cx))
            .unwrap_or_else(|| "untitled".into())
            .into();

        let project_file = project::File::from_dyn(snapshot.file());
        let repo_name = project_file
            .map(|file| file.worktree.read(cx).root_name_str())
            .unwrap_or("untitled")
            .into();
        let offset = position.to_offset(&snapshot);

        let project_state = self.get_or_init_zeta_project(project, cx);
        let events = project_state.events.clone();
        let has_events = !events.is_empty();
        let recent_buffers = project_state.recent_paths.iter().cloned();
        let http_client = cx.http_client();

        let recent_buffer_snapshots = recent_buffers
            .filter_map(|project_path| {
                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
                if active_buffer == &buffer {
                    None
                } else {
                    Some(buffer.read(cx).snapshot())
                }
            })
            .take(3)
            .collect::<Vec<_>>();

        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let result = cx.background_spawn({
            let snapshot = snapshot.clone();
            let diagnostic_search_range = diagnostic_search_range.clone();
            async move {
                let text = snapshot.text();

                let mut recent_changes = String::new();
                for event in events {
                    sweep_ai::write_event(event, &mut recent_changes).unwrap();
                }

                let mut file_chunks = recent_buffer_snapshots
                    .into_iter()
                    .map(|snapshot| {
                        let end_point = Point::new(30, 0).min(snapshot.max_point());
                        sweep_ai::FileChunk {
                            content: snapshot.text_for_range(Point::zero()..end_point).collect(),
                            file_path: snapshot
                                .file()
                                .map(|f| f.path().as_unix_str())
                                .unwrap_or("untitled")
                                .to_string(),
                            start_line: 0,
                            end_line: end_point.row as usize,
                            timestamp: snapshot.file().and_then(|file| {
                                Some(
                                    file.disk_state()
                                        .mtime()?
                                        .to_seconds_and_nanos_for_persistence()?
                                        .0,
                                )
                            }),
                        }
                    })
                    .collect::<Vec<_>>();

                let diagnostic_entries =
                    snapshot.diagnostics_in_range(diagnostic_search_range, false);
                let mut diagnostic_content = String::new();
                let mut diagnostic_count = 0;

                for entry in diagnostic_entries {
                    let start_point: Point = entry.range.start;

                    let severity = match entry.diagnostic.severity {
                        DiagnosticSeverity::ERROR => "error",
                        DiagnosticSeverity::WARNING => "warning",
                        DiagnosticSeverity::INFORMATION => "info",
                        DiagnosticSeverity::HINT => "hint",
                        _ => continue,
                    };

                    diagnostic_count += 1;

                    writeln!(
                        &mut diagnostic_content,
                        "{} at line {}: {}",
                        severity,
                        start_point.row + 1,
                        entry.diagnostic.message
                    )?;
                }

                if !diagnostic_content.is_empty() {
                    file_chunks.push(sweep_ai::FileChunk {
                        file_path: format!("Diagnostics for {}", full_path.display()),
                        start_line: 0,
                        end_line: diagnostic_count,
                        content: diagnostic_content,
                        timestamp: None,
                    });
                }

                let request_body = sweep_ai::AutocompleteRequest {
                    debug_info,
                    repo_name,
                    file_path: full_path.clone(),
                    file_contents: text.clone(),
                    original_file_contents: text,
                    cursor_position: offset,
                    recent_changes: recent_changes.clone(),
                    changes_above_cursor: true,
                    multiple_suggestions: false,
                    branch: None,
                    file_chunks,
                    retrieval_chunks: vec![],
                    recent_user_actions: vec![],
                    // TODO
                    privacy_mode_enabled: false,
                };

                let mut buf: Vec<u8> = Vec::new();
                let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
                serde_json::to_writer(writer, &request_body)?;
                let body: AsyncBody = buf.into();

                const SWEEP_API_URL: &str =
                    "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";

                let request = http_client::Request::builder()
                    .uri(SWEEP_API_URL)
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", api_token))
                    .header("Connection", "keep-alive")
                    .header("Content-Encoding", "br")
                    .method(Method::POST)
                    .body(body)?;

                let mut response = http_client.send(request).await?;

                let mut body: Vec<u8> = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;

                if !response.status().is_success() {
                    anyhow::bail!(
                        "Request failed with status: {:?}\nBody: {}",
                        response.status(),
                        String::from_utf8_lossy(&body),
                    );
                };

                let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;

                let old_text = snapshot
                    .text_for_range(response.start_index..response.end_index)
                    .collect::<String>();
                let edits = language::text_diff(&old_text, &response.completion)
                    .into_iter()
                    .map(|(range, text)| {
                        (
                            snapshot.anchor_after(response.start_index + range.start)
                                ..snapshot.anchor_before(response.start_index + range.end),
                            text,
                        )
                    })
                    .collect::<Vec<_>>();

                anyhow::Ok((response.autocomplete_id, edits, snapshot))
            }
        });

        let buffer = active_buffer.clone();
        let project = project.clone();
        let active_buffer = active_buffer.clone();

        cx.spawn(async move |this, cx| {
            let (id, edits, old_snapshot) = result.await?;

            if edits.is_empty() {
                if has_events
                    && allow_jump
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_with_sweep(
                                &project,
                                &jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            let Some((edits, new_snapshot, preview_task)) =
                buffer.read_with(cx, |buffer, cx| {
                    let new_snapshot = buffer.snapshot();

                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                        edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
                            .into();
                    let preview_task = buffer.preview_edits(edits.clone(), cx);

                    Some((edits, new_snapshot, preview_task))
                })?
            else {
                return anyhow::Ok(None);
            };

            let prediction = EditPrediction {
                id: EditPredictionId(id.into()),
                edits,
                snapshot: new_snapshot,
                edit_preview: preview_task.await,
                buffer,
            };

            anyhow::Ok(Some(prediction))
        })
    }

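    /// Finds a diagnostic to jump to. First looks for the diagnostic closest to
    /// the cursor in the active buffer that falls outside the range already
    /// covered by the last request; failing that, it opens the buffer with
    /// diagnostics whose path shares the most leading directories with the
    /// active one and returns its most severe diagnostic.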
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Find the closest diagnostic to the cursor that wasn't close enough to
        // be included in the last request.
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // Find the buffer with errors whose path shares the most
                        // parent directories with the active buffer.
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }

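    /// Requests a prediction from Zed's cloud backend: gathers the excerpt
    /// around the cursor plus retrieved context, diagnostics, and recent events
    /// into a `PredictEditsRequest`, builds the prompt locally, sends it, and
    /// parses the model's response (unified diff or XML edits, depending on the
    /// prompt format) into anchored edits against the right buffer.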
    fn request_prediction_with_zed_cloud(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal | PromptFormat::MinimalQwen => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }

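    /// Sends an OpenAI-style completion request to the configured endpoint
    /// (`ZED_PREDICT_EDITS_URL`, Ollama, or Zed's `/predict_edits/raw`). With
    /// `eval-support` enabled, responses are cached keyed on a hash of the URL
    /// and the serialized request, so evals can replay without network access.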
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }

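    /// Applies the side effects of an API response: records updated usage on
    /// success, and on `ZedUpdateRequiredError` flags that an update is required
    /// and shows an update notification before propagating the error.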
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }

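    /// POSTs a JSON request with the current LLM token and app-version headers.
    /// Retries exactly once after refreshing the token when the server reports
    /// it expired, and raises `ZedUpdateRequiredError` when the server's
    /// minimum-version header exceeds this build.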
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }

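    // Context retrieval runs immediately on the first edit after an idle gap of
    // CONTEXT_RETRIEVAL_IDLE_DURATION; otherwise it waits for a pause of
    // CONTEXT_RETRIEVAL_DEBOUNCE_DURATION in typing.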
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);

    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }

1911
1912 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1913 // and avoid spawning more than one concurrent task.
1914 pub fn refresh_context(
1915 &mut self,
1916 project: Entity<Project>,
1917 buffer: Entity<language::Buffer>,
1918 cursor_position: language::Anchor,
1919 cx: &mut Context<Self>,
1920 ) -> Task<Result<()>> {
1921 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
1922 return Task::ready(anyhow::Ok(()));
1923 };
1924
1925 let ContextMode::Agentic(options) = &self.options().context else {
1926 return Task::ready(anyhow::Ok(()));
1927 };
1928
1929 let snapshot = buffer.read(cx).snapshot();
1930 let cursor_point = cursor_position.to_point(&snapshot);
1931 let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
1932 cursor_point,
1933 &snapshot,
1934 &options.excerpt,
1935 None,
1936 ) else {
1937 return Task::ready(Ok(()));
1938 };
1939
1940 let app_version = AppVersion::global(cx);
1941 let client = self.client.clone();
1942 let llm_token = self.llm_token.clone();
1943 let debug_tx = self.debug_tx.clone();
1944 let current_file_path: Arc<Path> = snapshot
1945 .file()
1946 .map(|f| f.full_path(cx).into())
1947 .unwrap_or_else(|| Path::new("untitled").into());
1948
1949 let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
1950 predict_edits_v3::PlanContextRetrievalRequest {
1951 excerpt: cursor_excerpt.text(&snapshot).body,
1952 excerpt_path: current_file_path,
1953 excerpt_line_range: cursor_excerpt.line_range,
1954 cursor_file_max_row: Line(snapshot.max_point().row),
1955 events: zeta_project
1956 .events
1957 .iter()
1958 .filter_map(|ev| ev.to_request_event(cx))
1959 .collect(),
1960 },
1961 ) {
1962 Ok(prompt) => prompt,
1963 Err(err) => {
1964 return Task::ready(Err(err));
1965 }
1966 };
1967
1968 if let Some(debug_tx) = &debug_tx {
1969 debug_tx
1970 .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
1971 ZetaContextRetrievalStartedDebugInfo {
1972 project: project.clone(),
1973 timestamp: Instant::now(),
1974 search_prompt: prompt.clone(),
1975 },
1976 ))
1977 .ok();
1978 }
1979
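        // Compute the search tool's JSON schema once and reuse it across
        // requests; the tool's description is read from the schema itself.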
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

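            // Each tool call's arguments deserialize into a `SearchToolInput`.
            // An illustrative payload (shape only) looks like:
            // {"queries":[{"glob":"src/**/*.rs","syntax_node":[],"content":"fn main"}]}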
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }

    pub fn set_context(
        &mut self,
        project: Entity<Project>,
        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
    ) {
        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
            zeta_project.context = Some(context);
        }
    }

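    // Collects diagnostic groups sorted by distance from the cursor, stopping
    // once their serialized size would exceed `max_diagnostics_bytes`. The
    // returned flag reports whether any groups were dropped.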
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // sort by proximity to cursor
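        // Worked example: with the cursor at offset 12, a primary range of
        // 20..25 sorts with distance 8, a range of 0..5 with distance 7, and a
        // range of 10..20 that contains the cursor with the distance to its
        // nearer edge, 2.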
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                cursor_offset - range.end
            } else {
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }

    // TODO: Dedupe with similar code in request_prediction?
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Agentic context mode is not supported in zeta-cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }

    pub fn wait_for_initial_indexing(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        if let Some(syntax_index) = &zeta_project.syntax_index {
            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
        } else {
            Task::ready(Ok(()))
        }
    }
}

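/// Extracts the assistant's plain-text output from a completion response,
/// taking the last choice. Returns `None` (after logging an error) when the
/// response doesn't contain usable text.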
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
    let choice = res.choices.pop()?;
    let output_text = match choice.message {
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Plain(content)),
            ..
        } => content,
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Multipart(mut content)),
            ..
        } => {
            if content.is_empty() {
                log::error!("No output in completion response");
                return None;
            }

            match content.remove(0) {
                open_ai::MessagePart::Text { text } => text,
                open_ai::MessagePart::Image { .. } => {
                    log::error!("Expected text, got an image");
                    return None;
                }
            }
        }
        _ => {
            log::error!("Invalid response message: {:?}", choice.message);
            return None;
        }
    };
    Some(output_text)
}

#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}

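// Assembles a `PredictEditsRequest` from the gathered context: referenced
// declarations are resolved to worktree-relative paths, and parent signatures
// are interned into a shared table so nested declarations can refer to them
// by index.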
fn make_syntax_context_cloud_request(
    excerpt_path: Arc<Path>,
    context: EditPredictionContext,
    events: Vec<predict_edits_v3::Event>,
    can_collect_data: bool,
    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
    diagnostic_groups_truncated: bool,
    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
    debug_info: bool,
    worktrees: &[worktree::Snapshot],
    index_state: Option<&SyntaxIndexState>,
    prompt_max_bytes: Option<usize>,
    prompt_format: PromptFormat,
) -> predict_edits_v3::PredictEditsRequest {
    let mut signatures = Vec::new();
    let mut declaration_to_signature_index = HashMap::default();
    let mut referenced_declarations = Vec::new();

    for snippet in context.declarations {
        let project_entry_id = snippet.declaration.project_entry_id();
        let Some(path) = worktrees.iter().find_map(|worktree| {
            worktree.entry_for_id(project_entry_id).map(|entry| {
                let mut full_path = RelPathBuf::new();
                full_path.push(worktree.root_name());
                full_path.push(&entry.path);
                full_path
            })
        }) else {
            continue;
        };

        let parent_index = index_state.and_then(|index_state| {
            snippet.declaration.parent().and_then(|parent| {
                add_signature(
                    parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
        });

        let (text, text_is_truncated) = snippet.declaration.item_text();
        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
            path: path.as_std_path().into(),
            text: text.into(),
            range: snippet.declaration.item_line_range(),
            text_is_truncated,
            signature_range: snippet.declaration.signature_range_in_item_text(),
            parent_index,
            signature_score: snippet.score(DeclarationStyle::Signature),
            declaration_score: snippet.score(DeclarationStyle::Declaration),
            score_components: snippet.components,
        });
    }

    let excerpt_parent = index_state.and_then(|index_state| {
        context
            .excerpt
            .parent_declarations
            .last()
            .and_then(|(parent, _)| {
                add_signature(
                    *parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
    });

    predict_edits_v3::PredictEditsRequest {
        excerpt_path,
        excerpt: context.excerpt_text.body,
        excerpt_line_range: context.excerpt.line_range,
        excerpt_range: context.excerpt.range,
        cursor_point: predict_edits_v3::Point {
            line: predict_edits_v3::Line(context.cursor_point.row),
            column: context.cursor_point.column,
        },
        referenced_declarations,
        included_files: vec![],
        signatures,
        excerpt_parent,
        events,
        can_collect_data,
        diagnostic_groups,
        diagnostic_groups_truncated,
        git_info,
        debug_info,
        prompt_max_bytes,
        prompt_format,
    }
}

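// Interns the signature of `declaration_id` (and, recursively, its parents)
// into `signatures`, deduplicating via `declaration_to_signature_index`, and
// returns the index of the interned signature.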
fn add_signature(
    declaration_id: DeclarationId,
    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
    signatures: &mut Vec<Signature>,
    index: &SyntaxIndexState,
) -> Option<usize> {
    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
        return Some(*signature_index);
    }
    let Some(parent_declaration) = index.declaration(declaration_id) else {
        log::error!("bug: missing parent declaration");
        return None;
    };
    let parent_index = parent_declaration.parent().and_then(|parent| {
        add_signature(parent, declaration_to_signature_index, signatures, index)
    });
    let (text, text_is_truncated) = parent_declaration.signature_text();
    let signature_index = signatures.len();
    signatures.push(Signature {
        text: text.into(),
        text_is_truncated,
        parent_index,
        range: parent_declaration.signature_line_range(),
    });
    declaration_to_signature_index.insert(declaration_id, signature_index);
    Some(signature_index)
}

#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

#[cfg(feature = "eval-support")]
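/// Cache for LLM round trips during evals, keyed by entry kind plus what is
/// presumably a hash of the request (the `u64` in `EvalCacheKey`): `read`
/// returns a previously recorded response, and `write` records a
/// request/response pair for later reuse.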
pub trait EvalCache: Send + Sync {
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}

#[cfg(test)]
mod tests {
    use std::{path::Path, sync::Arc};

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::OffsetRangeExt as _;
    use open_ai::Usage;
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

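    // End-to-end check of prediction state: a local prediction in the edited
    // buffer, a context-refresh round trip, then a prediction whose edits land
    // in another file and surface as a jump.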
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for the current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

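    // A single prediction request/response round trip over the fake HTTP
    // client, checking the edit recovered from the returned diff.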
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

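    // Verifies that buffer edits made after registration are included in the
    // prompt as a unified diff of recent events.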
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Skipped until we start including diagnostics in the prompt
    // #[gpui::test]
    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
    //     let (zeta, mut req_rx) = init_test(cx);
    //     let fs = FakeFs::new(cx.executor());
    //     fs.insert_tree(
    //         "/root",
    //         json!({
    //             "foo.md": "Hello!\nBye"
    //         }),
    //     )
    //     .await;
    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
    //     let diagnostic = lsp::Diagnostic {
    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
    //         ..Default::default()
    //     };

    //     project.update(cx, |project, cx| {
    //         project.lsp_store().update(cx, |lsp_store, cx| {
    //             // Create some diagnostics
    //             lsp_store
    //                 .update_diagnostics(
    //                     LanguageServerId(0),
    //                     lsp::PublishDiagnosticsParams {
    //                         uri: path_to_buffer_uri.clone(),
    //                         diagnostics: vec![diagnostic],
    //                         version: None,
    //                     },
    //                     None,
    //                     language::DiagnosticSourceKind::Pushed,
    //                     &[],
    //                     cx,
    //                 )
    //                 .unwrap();
    //         });
    //     });

    //     let buffer = project
    //         .update(cx, |project, cx| {
    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
    //             project.open_buffer(path, cx)
    //         })
    //         .await
    //         .unwrap();

    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    //     let position = snapshot.anchor_before(language::Point::new(0, 0));

    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
    //         zeta.request_prediction(&project, &buffer, position, cx)
    //     });

    //     let (request, _respond_tx) = req_rx.next().await.unwrap();

    //     assert_eq!(request.diagnostic_groups.len(), 1);
    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
    //         .unwrap();
    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
    //     assert_eq!(
    //         value,
    //         json!({
    //             "entries": [{
    //                 "range": {
    //                     "start": 8,
    //                     "end": 10
    //                 },
    //                 "diagnostic": {
    //                     "source": null,
    //                     "code": null,
    //                     "code_description": null,
    //                     "severity": 1,
    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
    //                     "markdown": null,
    //                     "group_id": 0,
    //                     "is_primary": true,
    //                     "is_disk_based": false,
    //                     "is_unnecessary": false,
    //                     "source_kind": "Pushed",
    //                     "data": null,
    //                     "underline": true
    //                 }
    //             }],
    //             "primary_ix": 0
    //         })
    //     );
    // }

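    // Builds a minimal successful completion response whose only choice
    // contains `text` as plain assistant content.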
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }

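    // Extracts the prompt from a request, asserting that zeta requests carry
    // exactly one plain-text user message.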
    fn prompt_from_request(request: &open_ai::Request) -> &str {
        assert_eq!(request.messages.len(), 1);
        let open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(content),
            ..
        } = &request.messages[0]
        else {
            panic!(
                "Request does not have a single plain-text user message. {:#?}",
                request
            );
        };
        content
    }

    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

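            // Fake HTTP endpoints: token requests receive a stub LLM token,
            // and prediction requests are forwarded to the test through a
            // channel so it can inspect them and reply.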
            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}