1use crate::example::{ActualExcerpt, NamedExample};
2use crate::headless::ZetaCliAppState;
3use crate::paths::LOGS_DIR;
4use ::serde::Serialize;
5use anyhow::{Context as _, Result, anyhow};
6use clap::Args;
7use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
8use futures::StreamExt as _;
9use gpui::AsyncApp;
10use project::Project;
11use serde::Deserialize;
12use std::cell::Cell;
13use std::fs;
14use std::io::Write;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18
// Command-line arguments for the zeta2 `predict` subcommand.
// NOTE: plain `//` comments are used deliberately — `///` doc comments on a
// clap-derive struct would become `--help` text and change CLI output.
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Path to the named example file describing the worktree, edit history,
    // and cursor position to predict from.
    example_path: PathBuf,
    // Output format for the prediction result; defaults to markdown.
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
}
25
// How `PredictionDetails::write` serializes its output.
// NOTE: `//` comments (not `///`) so clap's derived help text is unchanged.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    // Pretty-printed JSON of the full PredictionDetails struct.
    Json,
    // Human-readable markdown report (excerpts, diff, timing breakdown).
    Md,
    // Just the predicted unified diff.
    Diff,
}
32pub async fn run_zeta2_predict(
33 args: PredictArguments,
34 app_state: &Arc<ZetaCliAppState>,
35 cx: &mut AsyncApp,
36) {
37 let example = NamedExample::load(args.example_path).unwrap();
38 let result = zeta2_predict(example, &app_state, cx).await.unwrap();
39 result.write(args.format, std::io::stdout()).unwrap();
40}
41
// Per-thread flag recording whether the client sign-in has already been
// performed, so repeated predictions on the same thread (e.g. across many
// examples) only authenticate once.
thread_local! {
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}
45
/// Runs a single zeta2 edit prediction against `example`, capturing the
/// context excerpts sent to the model, the resulting diff, and per-stage
/// timing information.
///
/// Side effects: writes debug artifacts into `LOGS_DIR` as the run
/// progresses (`search_prompt.md`, `search_queries.json`,
/// `prediction_prompt.md`, `prediction_response.md`), and signs the client
/// in on first use per thread.
///
/// # Errors
///
/// Returns an error if worktree setup, sign-in, project/worktree creation,
/// edit application, context retrieval, or the prediction request fails.
pub async fn zeta2_predict(
    example: NamedExample,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    fs::create_dir_all(&*LOGS_DIR)?;
    let worktree_path = example.setup_worktree().await?;

    // Sign in at most once per thread (see AUTHENTICATED above).
    // NOTE(review): the flag is set *before* the sign-in await completes, so
    // if sign-in fails here it will not be retried on this thread — confirm
    // that is intended.
    if !AUTHENTICATED.get() {
        AUTHENTICATED.set(true);

        app_state
            .client
            .sign_in_with_optional_connect(true, cx)
            .await?;
    }

    // Build a local project over the example's worktree.
    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    // Create the worktree and wait for the initial file scan to finish so all
    // files are visible to the project before edits are applied.
    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;

    // Register every buffer added to the store with zeta, so the buffers
    // opened while applying the edit history are tracked for prediction.
    cx.subscribe(&buffer_store, {
        let project = project.clone();
        move |_, event, cx| match event {
            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
                zeta2::Zeta::try_global(cx)
                    .unwrap()
                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    .detach();

    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    // Subscribe to zeta's debug event stream *before* starting context
    // retrieval so no events are missed.
    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

    let refresh_task = zeta.update(cx, |zeta, cx| {
        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
    })?;

    // Timestamps for each pipeline stage, filled in as debug events arrive;
    // used at the end to compute the per-stage durations.
    let mut context_retrieval_started_at = None;
    let mut context_retrieval_finished_at = None;
    let mut search_queries_generated_at = None;
    let mut search_queries_executed_at = None;
    let mut prediction_started_at = None;
    let mut prediction_finished_at = None;
    let mut excerpts_text = String::new();
    let mut prediction_task = None;
    let mut result = PredictionDetails::default();
    // Drive the run off the debug event stream: record timestamps, dump
    // artifacts, kick off the prediction once context retrieval finishes, and
    // break once the model response arrives.
    while let Some(event) = debug_rx.next().await {
        match event {
            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                context_retrieval_started_at = Some(info.timestamp);
                fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
            }
            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                search_queries_generated_at = Some(info.timestamp);
                fs::write(
                    LOGS_DIR.join("search_queries.json"),
                    serde_json::to_string_pretty(&info.regex_by_glob).unwrap(),
                )?;
            }
            zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                search_queries_executed_at = Some(info.timestamp);
            }
            zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
                context_retrieval_finished_at = Some(info.timestamp);

                // Context is ready — request the actual edit prediction.
                prediction_task = Some(zeta.update(cx, |zeta, cx| {
                    zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
                })?);
            }
            zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                prediction_started_at = Some(Instant::now());
                fs::write(
                    LOGS_DIR.join("prediction_prompt.md"),
                    &request.local_prompt.unwrap_or_default(),
                )?;

                // Collect the excerpts included in the request, both as
                // structured data (result.excerpts) and as rendered
                // codeblocks (excerpts_text), marking the cursor position in
                // the excerpt file.
                for included_file in request.request.included_files {
                    let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
                    result
                        .excerpts
                        .extend(included_file.excerpts.iter().map(|excerpt| ActualExcerpt {
                            // Skip the leading path component (presumably the
                            // worktree root — see the excerpts_text TODO on
                            // PredictionDetails).
                            path: included_file.path.components().skip(1).collect(),
                            text: String::from(excerpt.text.as_ref()),
                        }));
                    write_codeblock(
                        &included_file.path,
                        included_file.excerpts.iter(),
                        if included_file.path == request.request.excerpt_path {
                            &insertions
                        } else {
                            &[]
                        },
                        included_file.max_row,
                        false,
                        &mut excerpts_text,
                    );
                }

                // Wait for the model response, then stop listening.
                let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                let response = zeta2::text_from_response(response).unwrap_or_default();
                prediction_finished_at = Some(Instant::now());
                fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;

                break;
            }
        }
    }

    refresh_task.await.context("context retrieval failed")?;
    // prediction_task is always set by the time the loop breaks (the
    // EditPredictionRequested event follows ContextRetrievalFinished).
    let prediction = prediction_task.unwrap().await?;

    // Apply the predicted edits to the buffer and diff against the snapshot
    // taken before them; empty diff if there was no prediction.
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction.buffer.update(cx, |buffer, cx| {
                buffer.edit(prediction.edits.iter().cloned(), None, cx);
                buffer.text()
            })?;
            anyhow::Ok(language::unified_diff(&old_text, &new_text))
        })
        .transpose()?
        .unwrap_or_default();
    result.excerpts_text = excerpts_text;

    // Stage durations. These unwraps assume every debug event above was
    // observed; if the provider skips a stage this will panic.
    result.planning_search_time =
        search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
    result.running_search_time =
        search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap();
    result.filtering_search_time =
        context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
    result.prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
    result.total_time = prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap();

    anyhow::Ok(result)
}
210
/// The collected outcome of a single zeta2 prediction run: the predicted
/// diff, the context excerpts that were sent to the model, and a timing
/// breakdown of each pipeline stage.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    /// Unified diff of the predicted edits; empty if nothing was predicted.
    pub diff: String,
    /// Structured excerpts included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    /// Time spent generating search queries.
    pub planning_search_time: Duration,
    /// Time spent filtering search results into the final context.
    pub filtering_search_time: Duration,
    /// Time spent running the generated searches.
    pub running_search_time: Duration,
    /// Time spent waiting for the model's edit prediction.
    pub prediction_time: Duration,
    /// Wall-clock time from context retrieval start to prediction response.
    pub total_time: Duration,
}
222
223impl PredictionDetails {
224 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
225 let formatted = match format {
226 PredictionsOutputFormat::Md => self.to_markdown(),
227 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
228 PredictionsOutputFormat::Diff => self.diff.clone(),
229 };
230
231 Ok(out.write_all(formatted.as_bytes())?)
232 }
233
234 pub fn to_markdown(&self) -> String {
235 let inference_time =
236 self.planning_search_time + self.filtering_search_time + self.prediction_time;
237
238 format!(
239 "## Excerpts\n\n\
240 {}\n\n\
241 ## Prediction\n\n\
242 {}\n\n\
243 ## Time\n\n\
244 Planning searches: {}ms\n\
245 Running searches: {}ms\n\
246 Filtering context results: {}ms\n\
247 Making Prediction: {}ms\n\n\
248 -------------------\n\n\
249 Total: {}ms\n\
250 Inference: {}ms ({:.2}%)\n",
251 self.excerpts_text,
252 self.diff,
253 self.planning_search_time.as_millis(),
254 self.running_search_time.as_millis(),
255 self.filtering_search_time.as_millis(),
256 self.prediction_time.as_millis(),
257 self.total_time.as_millis(),
258 inference_time.as_millis(),
259 (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
260 )
261 }
262}