1use crate::example::{ActualExcerpt, NamedExample};
2
3use crate::headless::ZetaCliAppState;
4use ::serde::Serialize;
5use ::util::paths::PathStyle;
6use anyhow::{Context as _, Result, anyhow};
7use clap::Args;
8use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
9use futures::StreamExt as _;
10use gpui::AsyncApp;
11use language_model::LanguageModelRegistry;
12use project::{Project, ProjectPath};
13use serde::Deserialize;
14use std::cell::Cell;
15use std::io::Write;
16use std::path::PathBuf;
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use util::rel_path::RelPath;
20
// Command-line arguments for the `predict` subcommand.
// NOTE: comments here are deliberately `//` rather than `///` — clap derives
// user-facing help text from doc comments, and we don't want to change the
// CLI's help output.
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Path to the example file describing the worktree state, edit history,
    // and cursor position to predict from.
    example_path: PathBuf,
    // Output format for the prediction result; defaults to Markdown.
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
}
27
// Supported output formats for prediction results (see
// `PredictionDetails::write`). Plain `//` comments are used so clap's derived
// help text is unchanged.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    // Pretty-printed JSON of the full `PredictionDetails`.
    Json,
    // Human-readable Markdown report (the default).
    Md,
    // Only the unified diff of the predicted edit.
    Diff,
}
34pub async fn run_zeta2_predict(
35 args: PredictArguments,
36 app_state: &Arc<ZetaCliAppState>,
37 cx: &mut AsyncApp,
38) {
39 let example = NamedExample::load(args.example_path).unwrap();
40 let result = zeta2_predict(example, &app_state, cx).await.unwrap();
41 result.write(args.format, std::io::stdout()).unwrap();
42}
43
thread_local! {
    // Tracks whether this thread has already authenticated the language-model
    // provider and signed in the client, so repeated predictions on the same
    // thread can skip the handshake in `zeta2_predict`.
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}
47
48pub async fn zeta2_predict(
49 example: NamedExample,
50 app_state: &Arc<ZetaCliAppState>,
51 cx: &mut AsyncApp,
52) -> Result<PredictionDetails> {
53 let worktree_path = example.setup_worktree().await?;
54
55 if !AUTHENTICATED.get() {
56 AUTHENTICATED.set(true);
57
58 cx.update(|cx| {
59 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
60 registry
61 .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
62 .unwrap()
63 .authenticate(cx)
64 })
65 })?
66 .await?;
67
68 app_state
69 .client
70 .sign_in_with_optional_connect(true, cx)
71 .await?;
72 }
73
74 let project = cx.update(|cx| {
75 Project::local(
76 app_state.client.clone(),
77 app_state.node_runtime.clone(),
78 app_state.user_store.clone(),
79 app_state.languages.clone(),
80 app_state.fs.clone(),
81 None,
82 cx,
83 )
84 })?;
85
86 let worktree = project
87 .update(cx, |project, cx| {
88 project.create_worktree(&worktree_path, true, cx)
89 })?
90 .await?;
91 worktree
92 .read_with(cx, |worktree, _cx| {
93 worktree.as_local().unwrap().scan_complete()
94 })?
95 .await;
96
97 let _edited_buffers = example.apply_edit_history(&project, cx).await?;
98
99 let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
100
101 let cursor_buffer = project
102 .update(cx, |project, cx| {
103 project.open_buffer(
104 ProjectPath {
105 worktree_id: worktree.read(cx).id(),
106 path: cursor_path,
107 },
108 cx,
109 )
110 })?
111 .await?;
112
113 let cursor_offset_within_excerpt = example
114 .example
115 .cursor_position
116 .find(CURSOR_MARKER)
117 .ok_or_else(|| anyhow!("missing cursor marker"))?;
118 let mut cursor_excerpt = example.example.cursor_position.clone();
119 cursor_excerpt.replace_range(
120 cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
121 "",
122 );
123 let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
124 let text = buffer.text();
125
126 let mut matches = text.match_indices(&cursor_excerpt);
127 let Some((excerpt_offset, _)) = matches.next() else {
128 anyhow::bail!(
129 "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
130 );
131 };
132 assert!(matches.next().is_none());
133
134 Ok(excerpt_offset)
135 })??;
136
137 let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
138 let cursor_anchor =
139 cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
140
141 let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
142
143 let refresh_task = zeta.update(cx, |zeta, cx| {
144 zeta.register_buffer(&cursor_buffer, &project, cx);
145 zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
146 })?;
147
148 let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
149 let mut context_retrieval_started_at = None;
150 let mut context_retrieval_finished_at = None;
151 let mut search_queries_generated_at = None;
152 let mut search_queries_executed_at = None;
153 let mut prediction_started_at = None;
154 let mut prediction_finished_at = None;
155 let mut excerpts_text = String::new();
156 let mut prediction_task = None;
157 let mut result = PredictionDetails::default();
158 while let Some(event) = debug_rx.next().await {
159 match event {
160 zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
161 context_retrieval_started_at = Some(info.timestamp);
162 }
163 zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
164 search_queries_generated_at = Some(info.timestamp);
165 }
166 zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
167 search_queries_executed_at = Some(info.timestamp);
168 }
169 zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
170 context_retrieval_finished_at = Some(info.timestamp);
171
172 prediction_task = Some(zeta.update(cx, |zeta, cx| {
173 zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
174 })?);
175 }
176 zeta2::ZetaDebugInfo::EditPredicted(request) => {
177 prediction_started_at = Some(Instant::now());
178 request.response_rx.await?.map_err(|err| anyhow!(err))?;
179 prediction_finished_at = Some(Instant::now());
180
181 for included_file in request.request.included_files {
182 let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
183 result
184 .excerpts
185 .extend(included_file.excerpts.iter().map(|excerpt| ActualExcerpt {
186 path: included_file.path.components().skip(1).collect(),
187 text: String::from(excerpt.text.as_ref()),
188 }));
189 write_codeblock(
190 &included_file.path,
191 included_file.excerpts.iter(),
192 if included_file.path == request.request.excerpt_path {
193 &insertions
194 } else {
195 &[]
196 },
197 included_file.max_row,
198 false,
199 &mut excerpts_text,
200 );
201 }
202 break;
203 }
204 _ => {}
205 }
206 }
207
208 refresh_task.await.context("context retrieval failed")?;
209 let prediction = prediction_task.unwrap().await?;
210
211 result.diff = prediction
212 .map(|prediction| {
213 let old_text = prediction.snapshot.text();
214 let new_text = prediction.buffer.update(cx, |buffer, cx| {
215 buffer.edit(prediction.edits.iter().cloned(), None, cx);
216 buffer.text()
217 })?;
218 anyhow::Ok(language::unified_diff(&old_text, &new_text))
219 })
220 .transpose()?
221 .unwrap_or_default();
222 result.excerpts_text = excerpts_text;
223
224 result.planning_search_time =
225 search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
226 result.running_search_time =
227 search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap();
228 result.filtering_search_time =
229 context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
230 result.prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
231 result.total_time = prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap();
232
233 anyhow::Ok(result)
234}
235
/// The outcome of a single zeta2 prediction run: the predicted edit, the
/// context that was retrieved to make it, and a timing breakdown of each
/// phase.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    /// Unified diff of the predicted edit applied to the cursor buffer
    /// (empty when no prediction was made).
    pub diff: String,
    /// Context excerpts included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    /// Pre-rendered codeblock text of the excerpts, as shown in the Markdown
    /// report.
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    /// Time spent generating search queries for context retrieval.
    pub planning_search_time: Duration,
    /// Time spent filtering search results into the final context.
    pub filtering_search_time: Duration,
    /// Time spent executing the generated search queries.
    pub running_search_time: Duration,
    /// Time spent waiting for the edit-prediction response.
    pub prediction_time: Duration,
    /// Wall-clock time from the start of context retrieval to the end of
    /// prediction.
    pub total_time: Duration,
}
247
248impl PredictionDetails {
249 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
250 let formatted = match format {
251 PredictionsOutputFormat::Md => self.to_markdown(),
252 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
253 PredictionsOutputFormat::Diff => self.diff.clone(),
254 };
255
256 Ok(out.write_all(formatted.as_bytes())?)
257 }
258
259 pub fn to_markdown(&self) -> String {
260 let inference_time =
261 self.planning_search_time + self.filtering_search_time + self.prediction_time;
262
263 format!(
264 "## Excerpts\n\n\
265 {}\n\n\
266 ## Prediction\n\n\
267 {}\n\n\
268 ## Time\n\n\
269 Planning searches: {}ms\n\
270 Running searches: {}ms\n\
271 Filtering context results: {}ms\n\
272 Making Prediction: {}ms\n\n\
273 -------------------\n\n\
274 Total: {}ms\n\
275 Inference: {}ms ({:.2}%)\n",
276 self.excerpts_text,
277 self.diff,
278 self.planning_search_time.as_millis(),
279 self.running_search_time.as_millis(),
280 self.filtering_search_time.as_millis(),
281 self.prediction_time.as_millis(),
282 self.total_time.as_millis(),
283 inference_time.as_millis(),
284 (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
285 )
286 }
287}