1use crate::example::{ActualExcerpt, NamedExample};
2use crate::headless::ZetaCliAppState;
3use crate::paths::LOGS_DIR;
4use ::serde::Serialize;
5use anyhow::{Result, anyhow};
6use clap::Args;
7use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
8use futures::StreamExt as _;
9use gpui::{AppContext, AsyncApp};
10use project::Project;
11use serde::Deserialize;
12use std::cell::Cell;
13use std::fs;
14use std::io::Write;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::sync::Mutex;
18use std::time::{Duration, Instant};
19
// Command-line arguments for the `predict` subcommand.
// NOTE: plain `//` comments are used on purpose — `///` doc comments on a
// clap derive struct/fields would change the generated `--help` output.
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Path to the example file describing the worktree, edit history, and
    // cursor position to predict from.
    example_path: PathBuf,
    // Output format for the prediction result; defaults to Markdown.
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
}
26
// How `PredictionDetails` is rendered to stdout.
// NOTE: plain `//` comments on purpose — `///` on a ValueEnum's variants
// would become CLI help text and change program output.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    // Pretty-printed JSON of the full `PredictionDetails`.
    Json,
    // Human-readable Markdown report (excerpts, diff, timings).
    Md,
    // Only the unified diff of the predicted edit.
    Diff,
}
33pub async fn run_zeta2_predict(
34 args: PredictArguments,
35 app_state: &Arc<ZetaCliAppState>,
36 cx: &mut AsyncApp,
37) {
38 let example = NamedExample::load(args.example_path).unwrap();
39 let result = zeta2_predict(example, &app_state, cx).await.unwrap();
40 result.write(args.format, std::io::stdout()).unwrap();
41}
42
thread_local! {
    // Tracks whether sign-in has already been attempted on this thread, so
    // repeated predictions (e.g. from an eval loop) skip the sign-in flow.
    // NOTE(review): `zeta2_predict` sets this *before* awaiting the sign-in,
    // so a failed sign-in is never retried on this thread; and being
    // thread-local, another thread will sign in again — confirm intended.
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}
46
/// Runs one end-to-end edit prediction for `example` and collects the
/// predicted diff, the context excerpts sent to the model, and per-stage
/// timings into a [`PredictionDetails`].
///
/// Side effects: creates `LOGS_DIR` and writes debug artifacts into it
/// (`search_prompt.md`, `search_queries.json`, `prediction_prompt.md`,
/// `prediction_response.md`), signs the client in on first use per thread,
/// and materializes the example's worktree on disk.
pub async fn zeta2_predict(
    example: NamedExample,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    fs::create_dir_all(&*LOGS_DIR)?;
    let worktree_path = example.setup_worktree().await?;

    // Sign in at most once per thread (see `AUTHENTICATED` above).
    // NOTE(review): the flag is set before the sign-in future resolves, so a
    // failed sign-in is not retried on this thread — confirm intended.
    if !AUTHENTICATED.get() {
        AUTHENTICATED.set(true);

        app_state
            .client
            .sign_in_with_optional_connect(true, cx)
            .await?;
    }

    // Create a local (non-collab) project backed by the app state's services.
    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    // Add the example's worktree and wait for the initial file scan so that
    // subsequent path lookups see the full tree.
    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;

    // Register every buffer with Zeta as it is opened, so edit history and
    // context retrieval can observe it. Must be wired up before the edit
    // history below opens any buffers.
    cx.subscribe(&buffer_store, {
        let project = project.clone();
        move |_, event, cx| match event {
            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
                zeta2::Zeta::try_global(cx)
                    .unwrap()
                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    .detach();

    // Replay the example's edit history and locate the cursor position the
    // prediction should be made at.
    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    // Shared accumulator filled in by the debug task below and finalized at
    // the end of this function.
    let result = Arc::new(Mutex::new(PredictionDetails::default()));
    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

    // Background task: consume Zeta's debug event stream, dump prompts and
    // queries to LOGS_DIR, capture the included context excerpts, and record
    // per-stage timestamps. It exits after the first prediction completes.
    let debug_task = cx.background_spawn({
        let result = result.clone();
        async move {
            let mut context_retrieval_started_at = None;
            let mut context_retrieval_finished_at = None;
            let mut search_queries_generated_at = None;
            let mut search_queries_executed_at = None;
            while let Some(event) = debug_rx.next().await {
                match event {
                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                        context_retrieval_started_at = Some(info.timestamp);
                        fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                        search_queries_generated_at = Some(info.timestamp);
                        fs::write(
                            LOGS_DIR.join("search_queries.json"),
                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
                        )?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                        search_queries_executed_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
                        context_retrieval_finished_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                        let prediction_started_at = Instant::now();
                        fs::write(
                            LOGS_DIR.join("prediction_prompt.md"),
                            &request.local_prompt.unwrap_or_default(),
                        )?;

                        // Record the context excerpts that were sent with the
                        // request, both structurally and as rendered text.
                        {
                            let mut result = result.lock().unwrap();

                            for included_file in request.request.included_files {
                                // Cursor marker is only inserted into the file
                                // containing the excerpt under the cursor.
                                let insertions =
                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
                                result.excerpts.extend(included_file.excerpts.iter().map(
                                    // `skip(1)` drops the first path component —
                                    // presumably the worktree root; see the TODO
                                    // on `excerpts_text`.
                                    |excerpt| ActualExcerpt {
                                        path: included_file.path.components().skip(1).collect(),
                                        text: String::from(excerpt.text.as_ref()),
                                    },
                                ));
                                write_codeblock(
                                    &included_file.path,
                                    included_file.excerpts.iter(),
                                    if included_file.path == request.request.excerpt_path {
                                        &insertions
                                    } else {
                                        &[]
                                    },
                                    included_file.max_row,
                                    false,
                                    &mut result.excerpts_text,
                                );
                            }
                        }

                        // Wait for the model's response, then log it.
                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                        let response = zeta2::text_from_response(response).unwrap_or_default();
                        let prediction_finished_at = Instant::now();
                        fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;

                        let mut result = result.lock().unwrap();

                        // The unwraps assume the events above always arrive in
                        // retrieval -> generated -> executed -> finished order
                        // before a prediction is requested.
                        result.planning_search_time = search_queries_generated_at.unwrap()
                            - context_retrieval_started_at.unwrap();
                        result.running_search_time = search_queries_executed_at.unwrap()
                            - search_queries_generated_at.unwrap();
                        result.filtering_search_time = context_retrieval_finished_at.unwrap()
                            - search_queries_executed_at.unwrap();
                        result.prediction_time = prediction_finished_at - prediction_started_at;
                        result.total_time =
                            prediction_finished_at - context_retrieval_started_at.unwrap();

                        // One prediction per run: stop consuming debug events.
                        break;
                    }
                }
            }
            anyhow::Ok(())
        }
    });

    // Retrieve context for the cursor position, then request the prediction.
    zeta.update(cx, |zeta, cx| {
        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
    })?
    .await?;

    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    // Propagate any I/O or channel error from the debug task.
    debug_task.await?;

    // The debug task has finished, so this is the only remaining Arc handle.
    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
    // Apply the predicted edits to the buffer and diff against the snapshot
    // taken at request time; empty diff when no prediction was produced.
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction
                .buffer
                .update(cx, |buffer, cx| {
                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
                    buffer.text()
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    anyhow::Ok(result)
}
225
/// Collected output of a single prediction run: the predicted diff, the
/// context excerpts sent to the model, and per-stage timings.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    // Unified diff of the predicted edit; empty when no prediction was made.
    pub diff: String,
    // Context excerpts included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    // Time spent generating search queries.
    pub planning_search_time: Duration,
    // Time spent filtering search results into the final context.
    pub filtering_search_time: Duration,
    // Time spent executing the generated searches.
    pub running_search_time: Duration,
    // Time spent waiting for the model's prediction.
    pub prediction_time: Duration,
    // Wall-clock time from context-retrieval start to prediction finish.
    pub total_time: Duration,
}
237
238impl PredictionDetails {
239 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
240 let formatted = match format {
241 PredictionsOutputFormat::Md => self.to_markdown(),
242 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
243 PredictionsOutputFormat::Diff => self.diff.clone(),
244 };
245
246 Ok(out.write_all(formatted.as_bytes())?)
247 }
248
249 pub fn to_markdown(&self) -> String {
250 let inference_time =
251 self.planning_search_time + self.filtering_search_time + self.prediction_time;
252
253 format!(
254 "## Excerpts\n\n\
255 {}\n\n\
256 ## Prediction\n\n\
257 {}\n\n\
258 ## Time\n\n\
259 Planning searches: {}ms\n\
260 Running searches: {}ms\n\
261 Filtering context results: {}ms\n\
262 Making Prediction: {}ms\n\n\
263 -------------------\n\n\
264 Total: {}ms\n\
265 Inference: {}ms ({:.2}%)\n",
266 self.excerpts_text,
267 self.diff,
268 self.planning_search_time.as_millis(),
269 self.running_search_time.as_millis(),
270 self.filtering_search_time.as_millis(),
271 self.prediction_time.as_millis(),
272 self.total_time.as_millis(),
273 inference_time.as_millis(),
274 (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
275 )
276 }
277}