use anyhow::{Context as _, Result, anyhow, bail};
use arrayvec::ArrayVec;
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
use cloud_llm_client::{
    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
use collections::HashMap;
use edit_prediction_context::{
    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
    SyntaxIndex, SyntaxIndexState,
};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::AsyncReadExt as _;
use futures::channel::{mpsc, oneshot};
use gpui::http_client::{AsyncBody, Method};
use gpui::{
    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity,
    http_client, prelude::*,
};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use lsp::DiagnosticSeverity;
use open_ai::FunctionDefinition;
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use semver::Version;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};
use std::fmt::Write;
use std::ops::Range;
use std::path::Path;
use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use std::{env, mem};
use thiserror::Error;
use util::rel_path::RelPathBuf;
use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};

pub mod assemble_excerpts;
mod prediction;
mod provider;
pub mod retrieval_search;
mod sweep_ai;
pub mod udiff;
mod xml_edits;

use crate::assemble_excerpts::assemble_excerpts;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;

/// Maximum number of events to track, per edit prediction model.
const EVENT_COUNT_MAX_SWEEP: usize = 6;
const EVENT_COUNT_MAX_ZETA: usize = 16;
/// Maximum line distance between consecutive edits for them to be coalesced
/// into a single event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;

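/// Feature flag gating the Sweep edit prediction model.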
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &'static str = "sweep-ai";
}

pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});

pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        false
    }
}

#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

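/// Coordinates edit predictions across the app: it tracks per-project edit
/// history and prediction state, holds the LLM token and client used to reach
/// the prediction backends, and exposes debug/eval hooks.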
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    update_required: bool,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_api_token: Option<String>,
    sweep_ai_debug_info: Arc<str>,
}

#[derive(PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    ZedCloud,
    Sweep,
}

#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

impl ContextMode {
    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
        match self {
            ContextMode::Agentic(options) => &options.excerpt,
            ContextMode::Syntax(options) => &options.excerpt,
        }
    }
}

#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

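/// Per-project prediction state: recent edit events, recently active paths,
/// registered buffers, the current and pending predictions, and any context
/// retrieved for upcoming requests.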
struct ZetaProject {
    syntax_index: Option<Entity<SyntaxIndex>>,
    events: VecDeque<Event>,
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    _subscription: gpui::Subscription,
}

#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
}

impl CurrentEditPrediction {
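    /// Decides whether this newly requested prediction should replace the
    /// currently displayed one. A new prediction that no longer interpolates
    /// against its buffer's current contents never replaces; one for a
    /// different buffer always does. When both were requested from the same
    /// buffer and each consists of a single edit, the new prediction only
    /// wins if it edits the same range with text that extends the old
    /// suggestion, which reduces UI thrash from churning suggestions.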
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}

#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    Buffer(EntityId),
}

impl PredictionRequestedBy {
    pub fn buffer_id(&self) -> Option<EntityId> {
        match self {
            PredictionRequestedBy::DiagnosticsUpdate => None,
            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
        }
    }
}

struct PendingPrediction {
    id: usize,
    _task: Task<()>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

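/// An entry in a project's edit history, replayed to the model as context for
/// the next prediction.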
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        end_edit_anchor: Option<Anchor>,
        timestamp: Instant,
    },
}

impl Event {
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        // TODO: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }

    pub fn project_path(&self, cx: &App) -> Option<project::ProjectPath> {
        match self {
            Event::BufferChange { new_snapshot, .. } => new_snapshot
                .file()
                .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)),
        }
    }
}

impl Zeta {
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
    }

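    /// Returns the app-wide [`Zeta`] instance, creating it and registering it
    /// as a GPUI global on first use.
    ///
    /// A minimal usage sketch (assuming `client`, `user_store`, `project`,
    /// and a `&mut App` context are in scope):
    ///
    /// ```ignore
    /// let zeta = Zeta::global(&client, &user_store, cx);
    /// zeta.update(cx, |zeta, cx| zeta.register_project(&project, cx));
    /// ```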
    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
            sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
                .context("No SWEEP_AI_TOKEN environment variable set")
                .log_err(),
            sweep_ai_debug_info: sweep_ai::debug_info(cx),
        }
    }

    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }

    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_api_token.is_some()
    }

    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }

    pub fn history_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl DoubleEndedIterator<Item = &Event> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.events.iter())
            .into_iter()
            .flatten()
    }

    pub fn context_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
        self.projects
            .get(&project.entity_id())
            .and_then(|project| {
                Some(
                    project
                        .context
                        .as_ref()?
                        .iter()
                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
                )
            })
            .into_iter()
            .flatten()
    }

    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        if self.edit_prediction_model == ZetaEditPredictionModel::ZedCloud {
            self.user_store.read(cx).edit_prediction_usage()
        } else {
            None
        }
    }

    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }

    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                _subscription: cx.subscribe(project, Self::handle_project_event),
            })
    }

    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                self.refresh_prediction_from_diagnostics(project, cx);
            }
            _ => (),
        }
    }

    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

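    /// Records an edit to a registered buffer in the project's event history.
    /// An edit that lands within [`CHANGE_GROUPING_LINE_SPAN`] lines of the
    /// previous event in the same buffer is coalesced into that event;
    /// otherwise a new event is pushed, evicting the oldest one once the
    /// per-model event limit is reached.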
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let event_count_max = match self.edit_prediction_model {
            ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
            ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
        };

        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut zeta_project.events;

        if let Some(Event::BufferChange {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = events.back_mut()
        {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        if events.len() >= event_count_max {
            events.pop_front();
        }

        events.push_back(Event::BufferChange {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
            timestamp: Instant::now(),
        });
    }

    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }

    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        if self.edit_prediction_model != ZetaEditPredictionModel::ZedCloud {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        project_state.pending_predictions.clear();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.current_prediction.take();
            project_state.pending_predictions.clear();
        }
    }

    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
        self.projects
            .get(&project.entity_id())
            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
    }

    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(&project, &buffer, position, cx)
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(()));
            };

            let project = project.clone();
            cx.spawn(async move |cx| {
                if let Some(prediction) = request_task.await? {
                    this.update(cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project.entity_id())
                            .context("Project not found")?;

                        let new_prediction = CurrentEditPrediction {
                            requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
                            prediction,
                        };

                        if project_state
                            .current_prediction
                            .as_ref()
                            .is_none_or(|old_prediction| {
                                new_prediction.should_replace_prediction(old_prediction, cx)
                            })
                        {
                            project_state.current_prediction = Some(new_prediction);
                            cx.notify();
                        }
                        anyhow::Ok(())
                    })??;
                }
                Ok(())
            })
        })
    }

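    /// Requests a prediction in response to a diagnostics update, unless a
    /// prediction is already displayed. Opens the active entry's buffer,
    /// picks the next diagnostic location to jump to, and requests a
    /// prediction there.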
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        }

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(()));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(());
                };

                let Some(prediction) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(());
                };

                this.update(cx, |this, cx| {
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.current_prediction.get_or_insert_with(|| {
                            cx.notify();
                            CurrentEditPrediction {
                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
                                prediction,
                            }
                        });
                    }
                })?;

                anyhow::Ok(())
            })
        });
    }

    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;

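    /// Queues a prediction refresh, delaying it by [`Self::THROTTLE_TIMEOUT`]
    /// when the previous refresh was triggered by the same entity. At most
    /// two refreshes are tracked at once; queuing a third replaces the most
    /// recently queued one.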
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(WeakEntity<Self>, &mut AsyncApp) -> Task<Result<()>> + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        // TODO report cancelled requests like in zeta1
        let task = cx.spawn(async move |this, cx| {
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            do_refresh(this.clone(), cx).await.log_err();

            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                if zeta_project.pending_predictions[0].id == pending_prediction_id {
                    zeta_project.pending_predictions.remove(0);
                } else {
                    zeta_project.pending_predictions.clear();
                }

                cx.notify();
            })
            .ok();
        });

        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                _task: task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            zeta_project.pending_predictions.pop();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                _task: task,
            });
        }
    }

    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::ZedCloud => {
                self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
            }
            ZetaEditPredictionModel::Sweep => {
                self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
            }
        }
    }

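    /// Requests a prediction from the Sweep API. The request bundles the
    /// current file, recent changes, excerpts of recently active buffers, and
    /// nearby diagnostics, compressed with brotli. If the response contains
    /// no edits and `allow_jump` is set, a single follow-up request is made
    /// at the next diagnostic location.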
    fn request_prediction_with_sweep(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = active_buffer.read(cx).snapshot();
        let debug_info = self.sweep_ai_debug_info.clone();
        let Some(api_token) = self.sweep_api_token.clone() else {
            return Task::ready(Ok(None));
        };
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|file| file.full_path(cx))
            .unwrap_or_else(|| "untitled".into())
            .into();

        let project_file = project::File::from_dyn(snapshot.file());
        let repo_name = project_file
            .map(|file| file.worktree.read(cx).root_name_str())
            .unwrap_or("untitled")
            .into();
        let offset = position.to_offset(&snapshot);

        let project_state = self.get_or_init_zeta_project(project, cx);
        let events = project_state.events.clone();
        let has_events = !events.is_empty();
        let recent_buffers = project_state.recent_paths.iter().cloned();
        let http_client = cx.http_client();

        let recent_buffer_snapshots = recent_buffers
            .filter_map(|project_path| {
                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
                if active_buffer == &buffer {
                    None
                } else {
                    Some(buffer.read(cx).snapshot())
                }
            })
            .take(3)
            .collect::<Vec<_>>();

        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let result = cx.background_spawn({
            let snapshot = snapshot.clone();
            let diagnostic_search_range = diagnostic_search_range.clone();
            async move {
                let text = snapshot.text();

                let mut recent_changes = String::new();
                for event in events {
                    sweep_ai::write_event(event, &mut recent_changes)?;
                }

                let mut file_chunks = recent_buffer_snapshots
                    .into_iter()
                    .map(|snapshot| {
                        let end_point = Point::new(30, 0).min(snapshot.max_point());
                        sweep_ai::FileChunk {
                            content: snapshot.text_for_range(Point::zero()..end_point).collect(),
                            file_path: snapshot
                                .file()
                                .map(|f| f.path().as_unix_str())
                                .unwrap_or("untitled")
                                .to_string(),
                            start_line: 0,
                            end_line: end_point.row as usize,
                            timestamp: snapshot.file().and_then(|file| {
                                Some(
                                    file.disk_state()
                                        .mtime()?
                                        .to_seconds_and_nanos_for_persistence()?
                                        .0,
                                )
                            }),
                        }
                    })
                    .collect::<Vec<_>>();

                let diagnostic_entries =
                    snapshot.diagnostics_in_range(diagnostic_search_range, false);
                let mut diagnostic_content = String::new();
                let mut diagnostic_count = 0;

                for entry in diagnostic_entries {
                    let start_point: Point = entry.range.start;

                    let severity = match entry.diagnostic.severity {
                        DiagnosticSeverity::ERROR => "error",
                        DiagnosticSeverity::WARNING => "warning",
                        DiagnosticSeverity::INFORMATION => "info",
                        DiagnosticSeverity::HINT => "hint",
                        _ => continue,
                    };

                    diagnostic_count += 1;

                    writeln!(
                        &mut diagnostic_content,
                        "{} at line {}: {}",
                        severity,
                        start_point.row + 1,
                        entry.diagnostic.message
                    )?;
                }

                if !diagnostic_content.is_empty() {
                    file_chunks.push(sweep_ai::FileChunk {
                        file_path: format!("Diagnostics for {}", full_path.display()),
                        start_line: 0,
                        end_line: diagnostic_count,
                        content: diagnostic_content,
                        timestamp: None,
                    });
                }

                let request_body = sweep_ai::AutocompleteRequest {
                    debug_info,
                    repo_name,
                    file_path: full_path.clone(),
                    file_contents: text.clone(),
                    original_file_contents: text,
                    cursor_position: offset,
                    recent_changes: recent_changes.clone(),
                    changes_above_cursor: true,
                    multiple_suggestions: false,
                    branch: None,
                    file_chunks,
                    retrieval_chunks: vec![],
                    recent_user_actions: vec![],
                    // TODO
                    privacy_mode_enabled: false,
                };

                let mut buf: Vec<u8> = Vec::new();
                let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
                serde_json::to_writer(writer, &request_body)?;
                let body: AsyncBody = buf.into();

                const SWEEP_API_URL: &str =
                    "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";

                let request = http_client::Request::builder()
                    .uri(SWEEP_API_URL)
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", api_token))
                    .header("Connection", "keep-alive")
                    .header("Content-Encoding", "br")
                    .method(Method::POST)
                    .body(body)?;

                let mut response = http_client.send(request).await?;

                let mut body: Vec<u8> = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;

                if !response.status().is_success() {
                    anyhow::bail!(
                        "Request failed with status: {:?}\nBody: {}",
                        response.status(),
                        String::from_utf8_lossy(&body),
                    );
                }

                let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;

                let old_text = snapshot
                    .text_for_range(response.start_index..response.end_index)
                    .collect::<String>();
                let edits = language::text_diff(&old_text, &response.completion)
                    .into_iter()
                    .map(|(range, text)| {
                        (
                            snapshot.anchor_after(response.start_index + range.start)
                                ..snapshot.anchor_before(response.start_index + range.end),
                            text,
                        )
                    })
                    .collect::<Vec<_>>();

                anyhow::Ok((response.autocomplete_id, edits, snapshot))
            }
        });

        let buffer = active_buffer.clone();
        let project = project.clone();
        let active_buffer = active_buffer.clone();

        cx.spawn(async move |this, cx| {
            let (id, edits, old_snapshot) = result.await?;

            if edits.is_empty() {
                if has_events
                    && allow_jump
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_with_sweep(
                                &project,
                                &jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            let Some((edits, new_snapshot, preview_task)) =
                buffer.read_with(cx, |buffer, cx| {
                    let new_snapshot = buffer.snapshot();

                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                        edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
                            .into();
                    let preview_task = buffer.preview_edits(edits.clone(), cx);

                    Some((edits, new_snapshot, preview_task))
                })?
            else {
                return anyhow::Ok(None);
            };

            let prediction = EditPrediction {
                id: EditPredictionId(id.into()),
                edits,
                snapshot: new_snapshot,
                edit_preview: preview_task.await,
                buffer,
            };

            anyhow::Ok(Some(prediction))
        })
    }

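    /// Finds the next diagnostic to jump to: the diagnostic in the active
    /// buffer closest to the cursor that falls outside the range already
    /// covered by the last request, or failing that, the buffer with
    /// diagnostics whose path shares the most leading components with the
    /// active buffer's path, jumping to its most severe diagnostic.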
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = group.entries[group.primary_ix]
                    .range
                    .to_point(active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // Find the buffer with errors that shares the most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }

    fn request_prediction_with_zed_cloud(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal | PromptFormat::MinimalQwen => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }

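    /// Sends an OpenAI-style completion request, either to the URL override
    /// in `ZED_PREDICT_EDITS_URL` (or the local Ollama endpoint when
    /// `ZED_ZETA2_OLLAMA` is set) or to Zed's `/predict_edits/raw` endpoint.
    /// With the `eval-support` feature enabled, responses are memoized in the
    /// eval cache, keyed by a hash of the URL and serialized request.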
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }

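    /// Unwraps an API response, recording any returned usage information.
    /// When the server signals that this version of Zed is too old, flags
    /// `update_required` and shows an "Update Zed" notification.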
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }

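    /// Issues an authenticated POST to the LLM backend, enforcing the
    /// server's minimum-version header. If the server reports an expired LLM
    /// token, the token is refreshed and the request is retried once.
    ///
    /// A sketch of a call site (hypothetical `url`, `body`, and response
    /// type; mirrors the real callers above):
    ///
    /// ```ignore
    /// let (response, usage): (MyResponse, _) = Self::send_api_request(
    ///     |builder| Ok(builder.uri(url.as_ref()).body(body.into())?),
    ///     client,
    ///     llm_token,
    ///     app_version,
    /// )
    /// .await?;
    /// ```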
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1861
1862 pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
1863 pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1864
1865 // Refresh the related excerpts when the user just beguns editing after
1866 // an idle period, and after they pause editing.
1867 fn refresh_context_if_needed(
1868 &mut self,
1869 project: &Entity<Project>,
1870 buffer: &Entity<language::Buffer>,
1871 cursor_position: language::Anchor,
1872 cx: &mut Context<Self>,
1873 ) {
1874 if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
1875 return;
1876 }
1877
1878 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1879 return;
1880 };
1881
1882 let now = Instant::now();
1883 let was_idle = zeta_project
1884 .refresh_context_timestamp
1885 .map_or(true, |timestamp| {
1886 now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1887 });
1888 zeta_project.refresh_context_timestamp = Some(now);
1889 zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1890 let buffer = buffer.clone();
1891 let project = project.clone();
1892 async move |this, cx| {
1893 if was_idle {
1894 log::debug!("refetching edit prediction context after idle");
1895 } else {
1896 cx.background_executor()
1897 .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1898 .await;
1899 log::debug!("refetching edit prediction context after pause");
1900 }
1901 this.update(cx, |this, cx| {
1902 let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1903
1904 if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
1905 zeta_project.refresh_context_task = Some(task.log_err());
1906 };
1907 })
1908 .ok()
1909 }
1910 }));
1911 }
1912
1913 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1914 // and avoid spawning more than one concurrent task.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

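        // The search tool's JSON schema and description, derived once from
        // `SearchToolInput` and shared across all requests.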
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }

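    /// Replaces the retrieved context for `project`, if it has been
    /// registered (e.g. to inject externally computed excerpts instead of
    /// running [`Self::refresh_context`]).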
    pub fn set_context(
        &mut self,
        project: Entity<Project>,
        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
    ) {
        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
            zeta_project.context = Some(context);
        }
    }

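    /// Collects diagnostic groups from every language server, ordered by
    /// distance from the cursor, serializing them until the
    /// `max_diagnostics_bytes` budget is exhausted. Returns the serialized
    /// groups and whether any were dropped due to the budget.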
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // Sort by distance from the cursor; for a group whose primary range
        // spans the cursor, use the distance to the nearer end of the range.
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                cursor_offset - range.end
            } else {
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }

    // TODO: Dedupe with similar code in request_prediction?
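    /// Builds the cloud `PredictEditsRequest` that the zeta CLI sends, using
    /// syntax-index-based context gathering. Panics if agentic context mode
    /// is configured, which the CLI does not support yet.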
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Agentic context mode is not supported by the zeta CLI yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }

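    /// Returns a task that completes once the project's syntax index has
    /// finished its initial file indexing pass. Resolves immediately when the
    /// project has no syntax index.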
    pub fn wait_for_initial_indexing(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        if let Some(syntax_index) = &zeta_project.syntax_index {
            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
        } else {
            Task::ready(Ok(()))
        }
    }
}

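/// Extracts the assistant's text from the last choice of an OpenAI-style
/// completion response, logging and returning `None` when no usable text is
/// present. For multipart content, only the first part is used, and it must
/// be text.
///
/// A minimal sketch of the intended use (the response value is illustrative):
///
/// ```ignore
/// let response: open_ai::Response = /* a deserialized completion */;
/// if let Some(text) = text_from_response(response) {
///     println!("model replied: {text}");
/// }
/// ```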
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
    let choice = res.choices.pop()?;
    let output_text = match choice.message {
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Plain(content)),
            ..
        } => content,
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Multipart(mut content)),
            ..
        } => {
            if content.is_empty() {
                log::error!("No content parts in completion response");
                return None;
            }

            match content.remove(0) {
                open_ai::MessagePart::Text { text } => text,
                open_ai::MessagePart::Image { .. } => {
                    log::error!("Expected text, got an image");
                    return None;
                }
            }
        }
        _ => {
            log::error!("Invalid response message: {:?}", choice.message);
            return None;
        }
    };
    Some(output_text)
}

#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}

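/// Assembles a `predict_edits_v3::PredictEditsRequest` from syntax-based
/// context: resolves each retrieved declaration to a worktree-relative path,
/// deduplicates parent signatures via `add_signature`, and attaches scores,
/// events, and diagnostics.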
fn make_syntax_context_cloud_request(
    excerpt_path: Arc<Path>,
    context: EditPredictionContext,
    events: Vec<predict_edits_v3::Event>,
    can_collect_data: bool,
    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
    diagnostic_groups_truncated: bool,
    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
    debug_info: bool,
    worktrees: &Vec<worktree::Snapshot>,
    index_state: Option<&SyntaxIndexState>,
    prompt_max_bytes: Option<usize>,
    prompt_format: PromptFormat,
) -> predict_edits_v3::PredictEditsRequest {
    let mut signatures = Vec::new();
    let mut declaration_to_signature_index = HashMap::default();
    let mut referenced_declarations = Vec::new();

    for snippet in context.declarations {
        let project_entry_id = snippet.declaration.project_entry_id();
        let Some(path) = worktrees.iter().find_map(|worktree| {
            worktree.entry_for_id(project_entry_id).map(|entry| {
                let mut full_path = RelPathBuf::new();
                full_path.push(worktree.root_name());
                full_path.push(&entry.path);
                full_path
            })
        }) else {
            continue;
        };

        let parent_index = index_state.and_then(|index_state| {
            snippet.declaration.parent().and_then(|parent| {
                add_signature(
                    parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
        });

        let (text, text_is_truncated) = snippet.declaration.item_text();
        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
            path: path.as_std_path().into(),
            text: text.into(),
            range: snippet.declaration.item_line_range(),
            text_is_truncated,
            signature_range: snippet.declaration.signature_range_in_item_text(),
            parent_index,
            signature_score: snippet.score(DeclarationStyle::Signature),
            declaration_score: snippet.score(DeclarationStyle::Declaration),
            score_components: snippet.components,
        });
    }

    let excerpt_parent = index_state.and_then(|index_state| {
        context
            .excerpt
            .parent_declarations
            .last()
            .and_then(|(parent, _)| {
                add_signature(
                    *parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
    });

    predict_edits_v3::PredictEditsRequest {
        excerpt_path,
        excerpt: context.excerpt_text.body,
        excerpt_line_range: context.excerpt.line_range,
        excerpt_range: context.excerpt.range,
        cursor_point: predict_edits_v3::Point {
            line: predict_edits_v3::Line(context.cursor_point.row),
            column: context.cursor_point.column,
        },
        referenced_declarations,
        included_files: vec![],
        signatures,
        excerpt_parent,
        events,
        can_collect_data,
        diagnostic_groups,
        diagnostic_groups_truncated,
        git_info,
        debug_info,
        prompt_max_bytes,
        prompt_format,
    }
}

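/// Appends the signature of `declaration_id` to `signatures`, recursively
/// adding ancestor signatures first. Indices are memoized in
/// `declaration_to_signature_index` so each declaration's signature is
/// emitted at most once; returns the signature's index within `signatures`,
/// or `None` if the declaration is missing from the index.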
fn add_signature(
    declaration_id: DeclarationId,
    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
    signatures: &mut Vec<Signature>,
    index: &SyntaxIndexState,
) -> Option<usize> {
    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
        return Some(*signature_index);
    }
    let Some(parent_declaration) = index.declaration(declaration_id) else {
        log::error!("bug: missing parent declaration");
        return None;
    };
    let parent_index = parent_declaration.parent().and_then(|parent| {
        add_signature(parent, declaration_to_signature_index, signatures, index)
    });
    let (text, text_is_truncated) = parent_declaration.signature_text();
    let signature_index = signatures.len();
    signatures.push(Signature {
        text: text.into(),
        text_is_truncated,
        parent_index,
        range: parent_declaration.signature_line_range(),
    });
    declaration_to_signature_index.insert(declaration_id, signature_index);
    Some(signature_index)
}

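/// Cache key for a recorded LLM interaction during evals: the kind of
/// request, plus a `u64` identifying the specific request (e.g. a hash of
/// its contents).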
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

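/// Storage used by evals to record LLM requests and replay their responses,
/// keyed by [`EvalCacheKey`].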
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}

#[cfg(test)]
mod tests {
    use std::{path::Path, sync::Arc};

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::OffsetRangeExt as _;
    use open_ai::Usage;
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Skipped until we start including diagnostics in prompt
    // #[gpui::test]
    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
    //     let (zeta, mut req_rx) = init_test(cx);
    //     let fs = FakeFs::new(cx.executor());
    //     fs.insert_tree(
    //         "/root",
    //         json!({
    //             "foo.md": "Hello!\nBye"
    //         }),
    //     )
    //     .await;
    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
    //     let diagnostic = lsp::Diagnostic {
    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
    //         ..Default::default()
    //     };

    //     project.update(cx, |project, cx| {
    //         project.lsp_store().update(cx, |lsp_store, cx| {
    //             // Create some diagnostics
    //             lsp_store
    //                 .update_diagnostics(
    //                     LanguageServerId(0),
    //                     lsp::PublishDiagnosticsParams {
    //                         uri: path_to_buffer_uri.clone(),
    //                         diagnostics: vec![diagnostic],
    //                         version: None,
    //                     },
    //                     None,
    //                     language::DiagnosticSourceKind::Pushed,
    //                     &[],
    //                     cx,
    //                 )
    //                 .unwrap();
    //         });
    //     });

    //     let buffer = project
    //         .update(cx, |project, cx| {
    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
    //             project.open_buffer(path, cx)
    //         })
    //         .await
    //         .unwrap();

    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    //     let position = snapshot.anchor_before(language::Point::new(0, 0));

    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
    //         zeta.request_prediction(&project, &buffer, position, cx)
    //     });

    //     let (request, _respond_tx) = req_rx.next().await.unwrap();

    //     assert_eq!(request.diagnostic_groups.len(), 1);
    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
    //         .unwrap();
    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
    //     assert_eq!(
    //         value,
    //         json!({
    //             "entries": [{
    //                 "range": {
    //                     "start": 8,
    //                     "end": 10
    //                 },
    //                 "diagnostic": {
    //                     "source": null,
    //                     "code": null,
    //                     "code_description": null,
    //                     "severity": 1,
    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
    //                     "markdown": null,
    //                     "group_id": 0,
    //                     "is_primary": true,
    //                     "is_disk_based": false,
    //                     "is_unnecessary": false,
    //                     "source_kind": "Pushed",
    //                     "data": null,
    //                     "underline": true
    //                 }
    //             }],
    //             "primary_ix": 0
    //         })
    //     );
    // }

    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }

    fn prompt_from_request(request: &open_ai::Request) -> &str {
        assert_eq!(request.messages.len(), 1);
        let open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(content),
            ..
        } = &request.messages[0]
        else {
            panic!(
                "Request does not have single user message of type Plain. {:#?}",
                request
            );
        };
        content
    }

    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}