use crate::PromptFormat;
use crate::example::{ActualExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::{CACHE_DIR, LOGS_DIR};
use ::serde::Serialize;
use anyhow::{Result, anyhow};
use clap::Args;
use gpui::http_client::Url;
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp};
use project::Project;
use serde::Deserialize;
use std::cell::Cell;
use std::fs;
use std::io::Write;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use zeta2::LlmResponseCache;

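/// Arguments for running a single zeta2 edit prediction against a stored example.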
#[derive(Debug, Args)]
pub struct PredictArguments {
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    #[arg(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    example_path: PathBuf,
    #[arg(long)]
    skip_cache: bool,
}

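/// How prediction results are rendered: pretty-printed JSON, a markdown report, or
/// just the unified diff.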
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    Json,
    Md,
    Diff,
}
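
/// CLI entry point: loads the example, runs the prediction, and writes the result to
/// stdout, panicking on failure.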
pub async fn run_zeta2_predict(
    args: PredictArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    let example = NamedExample::load(args.example_path).unwrap();
    let result = zeta2_predict(example, args.skip_cache, args.prompt_format, app_state, cx)
        .await
        .unwrap();
    result.write(args.format, std::io::stdout()).unwrap();
}

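// Whether this thread's client has already signed in; avoids re-authenticating when
// multiple predictions run in the same process.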
thread_local! {
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}

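/// Runs one edit prediction for `example`: sets up a worktree and local project,
/// replays the example's edit history, retrieves context, requests a prediction,
/// and gathers timing details plus log artifacts along the way.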
pub async fn zeta2_predict(
    example: NamedExample,
    skip_cache: bool,
    prompt_format: PromptFormat,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    fs::create_dir_all(&*LOGS_DIR)?;
    let worktree_path = example.setup_worktree().await?;

    if !AUTHENTICATED.get() {
        AUTHENTICATED.set(true);

        app_state
            .client
            .sign_in_with_optional_connect(true, cx)
            .await?;
    }

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;

    zeta.update(cx, |zeta, _cx| {
        zeta.with_llm_response_cache(Arc::new(Cache { skip_cache }));
    })?;

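    // Register buffers with Zeta as the buffer store adds them, so their edits are tracked.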
    cx.subscribe(&buffer_store, {
        let project = project.clone();
        move |_, event, cx| match event {
            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
                zeta2::Zeta::try_global(cx)
                    .unwrap()
                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    .detach();

    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    let result = Arc::new(Mutex::new(PredictionDetails::default()));
    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

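    // Drain Zeta's debug event stream in the background: record a timestamp for each
    // phase and dump prompts, queries, and the model response under LOGS_DIR.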
    let debug_task = cx.background_spawn({
        let result = result.clone();
        async move {
            let mut context_retrieval_started_at = None;
            let mut context_retrieval_finished_at = None;
            let mut search_queries_generated_at = None;
            let mut search_queries_executed_at = None;
            while let Some(event) = debug_rx.next().await {
                match event {
                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                        context_retrieval_started_at = Some(info.timestamp);
                        fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                        search_queries_generated_at = Some(info.timestamp);
                        fs::write(
                            LOGS_DIR.join("search_queries.json"),
                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
                        )?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                        search_queries_executed_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
                        context_retrieval_finished_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                        let prediction_started_at = Instant::now();
                        fs::write(
                            LOGS_DIR.join("prediction_prompt.md"),
                            &request.local_prompt.unwrap_or_default(),
                        )?;

                        {
                            let mut result = result.lock().unwrap();

                            for included_file in request.request.included_files {
                                let insertions =
                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
                                result.excerpts.extend(included_file.excerpts.iter().map(
                                    |excerpt| ActualExcerpt {
                                        path: included_file.path.components().skip(1).collect(),
                                        text: String::from(excerpt.text.as_ref()),
                                    },
                                ));
                                write_codeblock(
                                    &included_file.path,
                                    included_file.excerpts.iter(),
                                    if included_file.path == request.request.excerpt_path {
                                        &insertions
                                    } else {
                                        &[]
                                    },
                                    included_file.max_row,
                                    false,
                                    &mut result.excerpts_text,
                                );
                            }
                        }

                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                        let response = zeta2::text_from_response(response).unwrap_or_default();
                        let prediction_finished_at = Instant::now();
                        fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;

                        let mut result = result.lock().unwrap();

                        result.planning_search_time = search_queries_generated_at.unwrap()
                            - context_retrieval_started_at.unwrap();
                        result.running_search_time = search_queries_executed_at.unwrap()
                            - search_queries_generated_at.unwrap();
                        result.filtering_search_time = context_retrieval_finished_at.unwrap()
                            - search_queries_executed_at.unwrap();
                        result.prediction_time = prediction_finished_at - prediction_started_at;
                        result.total_time =
                            prediction_finished_at - context_retrieval_started_at.unwrap();

                        break;
                    }
                }
            }
            anyhow::Ok(())
        }
    });

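    // Apply the requested prompt format, then run context retrieval for the cursor position.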
    zeta.update(cx, |zeta, cx| {
        let mut options = zeta.options().clone();
        options.prompt_format = prompt_format.into();
        zeta.set_options(options);
        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
    })?
    .await?;

    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    debug_task.await?;

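    // Apply the predicted edits to the buffer and diff the result against the
    // pre-edit snapshot to produce a unified diff.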
    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction
                .buffer
                .update(cx, |buffer, cx| {
                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
                    buffer.text()
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    anyhow::Ok(result)
}

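/// File-backed cache of LLM responses, keyed by a hash of the request URL and body.
/// With `skip_cache` set, existing entries are ignored on read, but fresh responses
/// are still written back.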
struct Cache {
    skip_cache: bool,
}

impl Cache {
    fn path(key: u64) -> PathBuf {
        CACHE_DIR.join(format!("{key:x}.json"))
    }
}

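// Keys are FxHasher hashes of URL + body: deterministic within a build, though not
// guaranteed stable across compiler or crate versions.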
impl LlmResponseCache for Cache {
    fn get_key(&self, url: &Url, body: &str) -> u64 {
        use collections::FxHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = FxHasher::default();
        url.hash(&mut hasher);
        body.hash(&mut hasher);
        hasher.finish()
    }

    fn read_response(&self, key: u64) -> Option<String> {
        let path = Cache::path(key);
        if path.exists() {
            if self.skip_cache {
                log::info!("Skipping existing cached LLM response: {}", path.display());
                None
            } else {
                log::info!("Using LLM response from cache: {}", path.display());
                Some(fs::read_to_string(path).unwrap())
            }
        } else {
            None
        }
    }

    fn write_response(&self, key: u64, value: &str) {
        fs::create_dir_all(&*CACHE_DIR).unwrap();

        let path = Cache::path(key);
        log::info!("Writing LLM response to cache: {}", path.display());
        fs::write(path, value).unwrap();
    }
}

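/// Everything collected from a prediction run: the resulting diff, the excerpts that
/// were included in the prompt, and per-phase timings.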
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    pub diff: String,
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    pub planning_search_time: Duration,
    pub filtering_search_time: Duration,
    pub running_search_time: Duration,
    pub prediction_time: Duration,
    pub total_time: Duration,
}

impl PredictionDetails {
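    /// Serializes the prediction in the requested output format and writes it to `out`.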
    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
        let formatted = match format {
            PredictionsOutputFormat::Md => self.to_markdown(),
            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
            PredictionsOutputFormat::Diff => self.diff.clone(),
        };

        Ok(out.write_all(formatted.as_bytes())?)
    }

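    /// Renders a markdown report with the excerpts, the diff, and a timing breakdown.
    /// "Inference" sums the LLM-bound phases (planning, filtering, prediction) and
    /// excludes local search execution.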
    pub fn to_markdown(&self) -> String {
        let inference_time =
            self.planning_search_time + self.filtering_search_time + self.prediction_time;

        format!(
            "## Excerpts\n\n\
             {}\n\n\
             ## Prediction\n\n\
             {}\n\n\
             ## Time\n\n\
             Planning searches: {}ms\n\
             Running searches: {}ms\n\
             Filtering context results: {}ms\n\
             Making prediction: {}ms\n\n\
             -------------------\n\n\
             Total: {}ms\n\
             Inference: {}ms ({:.2}%)\n",
            self.excerpts_text,
            self.diff,
            self.planning_search_time.as_millis(),
            self.running_search_time.as_millis(),
            self.filtering_search_time.as_millis(),
            self.prediction_time.as_millis(),
            self.total_time.as_millis(),
            inference_time.as_millis(),
            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
        )
    }
}