1use crate::PromptFormat;
2use crate::example::{ActualExcerpt, NamedExample};
3use crate::headless::ZetaCliAppState;
4use crate::paths::LOGS_DIR;
5use ::serde::Serialize;
6use anyhow::{Result, anyhow};
7use clap::Args;
8// use cloud_llm_client::predict_edits_v3::PromptFormat;
9use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
10use futures::StreamExt as _;
11use gpui::{AppContext, AsyncApp};
12use project::Project;
13use serde::Deserialize;
14use std::cell::Cell;
15use std::fs;
16use std::io::Write;
17use std::path::PathBuf;
18use std::sync::Arc;
19use std::sync::Mutex;
20use std::time::{Duration, Instant};
21
/// CLI arguments for the `predict` subcommand.
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Prompt format forwarded to zeta2 via `Zeta::set_options` before predicting.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // How to render the resulting `PredictionDetails` on stdout (markdown by default).
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    // Path of the named example to load and run the prediction against.
    example_path: PathBuf,
}
30
/// Output formats accepted by `PredictionDetails::write`.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    // Pretty-printed JSON of the whole `PredictionDetails` value.
    Json,
    // Human-readable markdown report (excerpts, diff, timings).
    Md,
    // Only the unified diff of the predicted edits.
    Diff,
}
37pub async fn run_zeta2_predict(
38 args: PredictArguments,
39 app_state: &Arc<ZetaCliAppState>,
40 cx: &mut AsyncApp,
41) {
42 let example = NamedExample::load(args.example_path).unwrap();
43 let result = zeta2_predict(example, args.prompt_format, &app_state, cx)
44 .await
45 .unwrap();
46 result.write(args.format, std::io::stdout()).unwrap();
47}
48
// Per-thread flag recording whether the client has already been signed in,
// so repeated `zeta2_predict` calls on the same thread authenticate once.
thread_local! {
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}
52
53pub async fn zeta2_predict(
54 example: NamedExample,
55 prompt_format: PromptFormat,
56 app_state: &Arc<ZetaCliAppState>,
57 cx: &mut AsyncApp,
58) -> Result<PredictionDetails> {
59 fs::create_dir_all(&*LOGS_DIR)?;
60 let worktree_path = example.setup_worktree().await?;
61
62 if !AUTHENTICATED.get() {
63 AUTHENTICATED.set(true);
64
65 app_state
66 .client
67 .sign_in_with_optional_connect(true, cx)
68 .await?;
69 }
70
71 let project = cx.update(|cx| {
72 Project::local(
73 app_state.client.clone(),
74 app_state.node_runtime.clone(),
75 app_state.user_store.clone(),
76 app_state.languages.clone(),
77 app_state.fs.clone(),
78 None,
79 cx,
80 )
81 })?;
82
83 let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
84
85 let worktree = project
86 .update(cx, |project, cx| {
87 project.create_worktree(&worktree_path, true, cx)
88 })?
89 .await?;
90 worktree
91 .read_with(cx, |worktree, _cx| {
92 worktree.as_local().unwrap().scan_complete()
93 })?
94 .await;
95
96 let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
97
98 cx.subscribe(&buffer_store, {
99 let project = project.clone();
100 move |_, event, cx| match event {
101 project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
102 zeta2::Zeta::try_global(cx)
103 .unwrap()
104 .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
105 }
106 _ => {}
107 }
108 })?
109 .detach();
110
111 let _edited_buffers = example.apply_edit_history(&project, cx).await?;
112 let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
113
114 let result = Arc::new(Mutex::new(PredictionDetails::default()));
115 let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
116
117 let debug_task = cx.background_spawn({
118 let result = result.clone();
119 async move {
120 let mut context_retrieval_started_at = None;
121 let mut context_retrieval_finished_at = None;
122 let mut search_queries_generated_at = None;
123 let mut search_queries_executed_at = None;
124 while let Some(event) = debug_rx.next().await {
125 match event {
126 zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
127 context_retrieval_started_at = Some(info.timestamp);
128 fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
129 }
130 zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
131 search_queries_generated_at = Some(info.timestamp);
132 fs::write(
133 LOGS_DIR.join("search_queries.json"),
134 serde_json::to_string_pretty(&info.search_queries).unwrap(),
135 )?;
136 }
137 zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
138 search_queries_executed_at = Some(info.timestamp);
139 }
140 zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
141 context_retrieval_finished_at = Some(info.timestamp);
142 }
143 zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
144 let prediction_started_at = Instant::now();
145 fs::write(
146 LOGS_DIR.join("prediction_prompt.md"),
147 &request.local_prompt.unwrap_or_default(),
148 )?;
149
150 {
151 let mut result = result.lock().unwrap();
152
153 for included_file in request.request.included_files {
154 let insertions =
155 vec![(request.request.cursor_point, CURSOR_MARKER)];
156 result.excerpts.extend(included_file.excerpts.iter().map(
157 |excerpt| ActualExcerpt {
158 path: included_file.path.components().skip(1).collect(),
159 text: String::from(excerpt.text.as_ref()),
160 },
161 ));
162 write_codeblock(
163 &included_file.path,
164 included_file.excerpts.iter(),
165 if included_file.path == request.request.excerpt_path {
166 &insertions
167 } else {
168 &[]
169 },
170 included_file.max_row,
171 false,
172 &mut result.excerpts_text,
173 );
174 }
175 }
176
177 let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
178 let response = zeta2::text_from_response(response).unwrap_or_default();
179 let prediction_finished_at = Instant::now();
180 fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;
181
182 let mut result = result.lock().unwrap();
183
184 result.planning_search_time = search_queries_generated_at.unwrap()
185 - context_retrieval_started_at.unwrap();
186 result.running_search_time = search_queries_executed_at.unwrap()
187 - search_queries_generated_at.unwrap();
188 result.filtering_search_time = context_retrieval_finished_at.unwrap()
189 - search_queries_executed_at.unwrap();
190 result.prediction_time = prediction_finished_at - prediction_started_at;
191 result.total_time =
192 prediction_finished_at - context_retrieval_started_at.unwrap();
193
194 break;
195 }
196 }
197 }
198 anyhow::Ok(())
199 }
200 });
201
202 zeta.update(cx, |zeta, cx| {
203 let mut options = zeta.options().clone();
204 options.prompt_format = prompt_format.into();
205 zeta.set_options(options);
206 zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
207 })?
208 .await?;
209
210 let prediction = zeta
211 .update(cx, |zeta, cx| {
212 zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
213 })?
214 .await?;
215
216 debug_task.await?;
217
218 let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
219 result.diff = prediction
220 .map(|prediction| {
221 let old_text = prediction.snapshot.text();
222 let new_text = prediction
223 .buffer
224 .update(cx, |buffer, cx| {
225 buffer.edit(prediction.edits.iter().cloned(), None, cx);
226 buffer.text()
227 })
228 .unwrap();
229 language::unified_diff(&old_text, &new_text)
230 })
231 .unwrap_or_default();
232
233 anyhow::Ok(result)
234}
235
/// Everything collected from a single prediction run: the resulting diff,
/// the context excerpts sent to the model, and per-phase timings.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    // Unified diff of the predicted edits; empty when no prediction was made.
    pub diff: String,
    // Context excerpts included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    // Time from context-retrieval start until search queries were generated.
    pub planning_search_time: Duration,
    // Time from search execution until context retrieval finished.
    pub filtering_search_time: Duration,
    // Time from query generation until the searches were executed.
    pub running_search_time: Duration,
    // Time from the prediction request until its response arrived.
    pub prediction_time: Duration,
    // Time from context-retrieval start until the prediction response.
    pub total_time: Duration,
}
247
248impl PredictionDetails {
249 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
250 let formatted = match format {
251 PredictionsOutputFormat::Md => self.to_markdown(),
252 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
253 PredictionsOutputFormat::Diff => self.diff.clone(),
254 };
255
256 Ok(out.write_all(formatted.as_bytes())?)
257 }
258
259 pub fn to_markdown(&self) -> String {
260 let inference_time =
261 self.planning_search_time + self.filtering_search_time + self.prediction_time;
262
263 format!(
264 "## Excerpts\n\n\
265 {}\n\n\
266 ## Prediction\n\n\
267 {}\n\n\
268 ## Time\n\n\
269 Planning searches: {}ms\n\
270 Running searches: {}ms\n\
271 Filtering context results: {}ms\n\
272 Making Prediction: {}ms\n\n\
273 -------------------\n\n\
274 Total: {}ms\n\
275 Inference: {}ms ({:.2}%)\n",
276 self.excerpts_text,
277 self.diff,
278 self.planning_search_time.as_millis(),
279 self.running_search_time.as_millis(),
280 self.filtering_search_time.as_millis(),
281 self.prediction_time.as_millis(),
282 self.total_time.as_millis(),
283 inference_time.as_millis(),
284 (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
285 )
286 }
287}