use crate::PromptFormat;
use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::{
    CACHE_DIR, LOGS_DIR, LOGS_PREDICTION_PROMPT, LOGS_PREDICTION_RESPONSE, LOGS_SEARCH_PROMPT,
    LOGS_SEARCH_QUERIES,
};
use anyhow::{Result, anyhow};
use clap::Args;
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use collections::HashMap;
use futures::StreamExt as _;
use gpui::http_client::Url;
use gpui::{AppContext, AsyncApp, Entity};
use language::{Anchor, Buffer, Point};
use project::Project;
use serde::{Deserialize, Serialize};
use std::cell::Cell;
use std::fs;
use std::io::Write;
use std::ops::Range;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use zeta2::LlmResponseCache;

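/// Arguments for the `predict` command, which runs a single Zeta2 edit
/// prediction against a stored example.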
#[derive(Debug, Args)]
pub struct PredictArguments {
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    #[arg(long)]
    use_expected_context: bool,
    #[arg(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
    example_path: PathBuf,
    #[arg(long)]
    skip_cache: bool,
}

#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    Json,
    Md,
    Diff,
}

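/// CLI entry point: loads the example, runs the prediction, writes the
/// formatted result to stdout, and prints where the prompt and response logs
/// were written.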
pub async fn run_zeta2_predict(
    args: PredictArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    let example = NamedExample::load(args.example_path).unwrap();
    let result = zeta2_predict(
        example,
        args.skip_cache,
        args.prompt_format,
        args.use_expected_context,
        app_state,
        cx,
    )
    .await
    .unwrap();
    result.write(args.format, std::io::stdout()).unwrap();

    println!("## Logs\n");
    println!("Search prompt: {}", LOGS_SEARCH_PROMPT.display());
    println!("Search queries: {}", LOGS_SEARCH_QUERIES.display());
    println!("Prediction prompt: {}", LOGS_PREDICTION_PROMPT.display());
    println!(
        "Prediction response: {}",
        LOGS_PREDICTION_RESPONSE.display()
    );
}

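// Tracks whether we've already signed in, so repeated predictions on the same
// thread don't re-authenticate.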
thread_local! {
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}

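/// Runs a single prediction for `example`: sets up a local project and
/// worktree, replays the example's edit history, gathers context (either the
/// example's expected context or a live context retrieval from the cursor),
/// and requests an edit prediction, collecting timings and logs along the way.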
pub async fn zeta2_predict(
    example: NamedExample,
    skip_cache: bool,
    prompt_format: PromptFormat,
    use_expected_context: bool,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    fs::create_dir_all(&*LOGS_DIR)?;
    let worktree_path = example.setup_worktree().await?;

    if !AUTHENTICATED.get() {
        AUTHENTICATED.set(true);

        app_state
            .client
            .sign_in_with_optional_connect(true, cx)
            .await?;
    }

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;

    zeta.update(cx, |zeta, _cx| {
        zeta.with_llm_response_cache(Arc::new(Cache { skip_cache }));
    })?;

    cx.subscribe(&buffer_store, {
        let project = project.clone();
        move |_, event, cx| match event {
            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
                zeta2::Zeta::try_global(cx)
                    .unwrap()
                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    .detach();

    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    let result = Arc::new(Mutex::new(PredictionDetails::default()));
    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

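    // Listen to Zeta's debug events in the background to capture the prompts,
    // search queries, timings, and the excerpts that were sent to the model.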
    let debug_task = cx.background_spawn({
        let result = result.clone();
        async move {
            let mut start_time = None;
            let mut search_queries_generated_at = None;
            let mut search_queries_executed_at = None;
            while let Some(event) = debug_rx.next().await {
                match event {
                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                        start_time = Some(info.timestamp);
                        fs::write(&*LOGS_SEARCH_PROMPT, &info.search_prompt)?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                        search_queries_generated_at = Some(info.timestamp);
                        fs::write(
                            &*LOGS_SEARCH_QUERIES,
                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
                        )?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                        search_queries_executed_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                        let prediction_started_at = Instant::now();
                        start_time.get_or_insert(prediction_started_at);
                        fs::write(
                            &*LOGS_PREDICTION_PROMPT,
                            &request.local_prompt.unwrap_or_default(),
                        )?;

                        {
                            let mut result = result.lock().unwrap();

                            for included_file in request.request.included_files {
                                let insertions =
                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
                                result.excerpts.extend(included_file.excerpts.iter().map(
                                    |excerpt| ActualExcerpt {
                                        path: included_file.path.components().skip(1).collect(),
                                        text: String::from(excerpt.text.as_ref()),
                                    },
                                ));
                                write_codeblock(
                                    &included_file.path,
                                    included_file.excerpts.iter(),
                                    if included_file.path == request.request.excerpt_path {
                                        &insertions
                                    } else {
                                        &[]
                                    },
                                    included_file.max_row,
                                    false,
                                    &mut result.excerpts_text,
                                );
                            }
                        }

                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                        let response = zeta2::text_from_response(response).unwrap_or_default();
                        let prediction_finished_at = Instant::now();
                        fs::write(&*LOGS_PREDICTION_RESPONSE, &response)?;

                        let mut result = result.lock().unwrap();

                        if !use_expected_context {
                            result.planning_search_time =
                                Some(search_queries_generated_at.unwrap() - start_time.unwrap());
                            result.running_search_time = Some(
                                search_queries_executed_at.unwrap()
                                    - search_queries_generated_at.unwrap(),
                            );
                        }
                        result.prediction_time = prediction_finished_at - prediction_started_at;
                        result.total_time = prediction_finished_at - start_time.unwrap();

                        break;
                    }
                }
            }
            anyhow::Ok(())
        }
    });

    zeta.update(cx, |zeta, _cx| {
        let mut options = zeta.options().clone();
        options.prompt_format = prompt_format.into();
        zeta.set_options(options);
    })?;

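    // Either seed the context directly from the example's expected excerpts,
    // or run the real context-retrieval pipeline from the cursor position.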
    if use_expected_context {
        let context_excerpts_tasks = example
            .example
            .expected_context
            .iter()
            .flat_map(|section| {
                section.alternatives[0].excerpts.iter().map(|excerpt| {
                    resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
                })
            })
            .collect::<Vec<_>>();
        let context_excerpts_vec = futures::future::try_join_all(context_excerpts_tasks).await?;

        let mut context_excerpts = HashMap::default();
        for (buffer, mut excerpts) in context_excerpts_vec {
            context_excerpts
                .entry(buffer)
                .or_insert(Vec::new())
                .append(&mut excerpts);
        }

        zeta.update(cx, |zeta, _cx| {
            zeta.set_context(project.clone(), context_excerpts)
        })?;
    } else {
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
        })?
        .await?;
    }

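    // With context in place, request the actual edit prediction.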
    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    debug_task.await?;

    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction
                .buffer
                .update(cx, |buffer, cx| {
                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
                    buffer.text()
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    anyhow::Ok(result)
}

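/// Resolves an expected excerpt from the example into anchor ranges in the
/// corresponding buffer, by locating the excerpt's text and offsetting its
/// required line numbers from there.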
async fn resolve_context_entry(
    project: Entity<Project>,
    excerpt: ExpectedExcerpt,
    mut cx: AsyncApp,
) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
    let buffer = project
        .update(&mut cx, |project, cx| {
            let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
            project.open_buffer(project_path, cx)
        })?
        .await?;

    let ranges = buffer.read_with(&mut cx, |buffer, _| {
        let full_text = buffer.text();
        let offset = full_text
            .find(&excerpt.text)
            .expect("Expected context not found");
        let point = buffer.offset_to_point(offset);
        excerpt
            .required_lines
            .iter()
            .map(|line| {
                let row = point.row + line.0;
                let range = Point::new(row, 0)..Point::new(row + 1, 0);
                buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
            })
            .collect()
    })?;

    Ok((buffer, ranges))
}

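/// On-disk cache of LLM responses, keyed by a hash of the request URL and
/// body, so repeated runs of the same example can skip the model round-trip.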
struct Cache {
    skip_cache: bool,
}

impl Cache {
    fn path(key: u64) -> PathBuf {
        CACHE_DIR.join(format!("{key:x}.json"))
    }
}

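// Note: keys are derived with `FxHasher`, whose output is stable within a
// build but not guaranteed across compiler or dependency upgrades, so old
// cache entries may simply stop being read after an upgrade.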
impl LlmResponseCache for Cache {
    fn get_key(&self, url: &Url, body: &str) -> u64 {
        use collections::FxHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = FxHasher::default();
        url.hash(&mut hasher);
        body.hash(&mut hasher);
        hasher.finish()
    }

    fn read_response(&self, key: u64) -> Option<String> {
        let path = Cache::path(key);
        if path.exists() {
            if self.skip_cache {
                log::info!("Skipping existing cached LLM response: {}", path.display());
                None
            } else {
                log::info!("Using LLM response from cache: {}", path.display());
                Some(fs::read_to_string(path).unwrap())
            }
        } else {
            None
        }
    }

    fn write_response(&self, key: u64, value: &str) {
        fs::create_dir_all(&*CACHE_DIR).unwrap();

        let path = Cache::path(key);
        log::info!("Writing LLM response to cache: {}", path.display());
        fs::write(path, value).unwrap();
    }
}

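/// The collected output of a prediction run: the predicted diff, the context
/// excerpts that were sent to the model, and timing breakdowns.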
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    pub diff: String,
    pub excerpts: Vec<ActualExcerpt>,
    // TODO: This contains the worktree root path. Drop the field and compute it on the fly.
    pub excerpts_text: String,
    pub planning_search_time: Option<Duration>,
    pub running_search_time: Option<Duration>,
    pub prediction_time: Duration,
    pub total_time: Duration,
}

impl PredictionDetails {
    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
        let formatted = match format {
            PredictionsOutputFormat::Md => self.to_markdown(),
            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
            PredictionsOutputFormat::Diff => self.diff.clone(),
        };

        Ok(out.write_all(formatted.as_bytes())?)
    }

    pub fn to_markdown(&self) -> String {
        // Inference time covers the LLM calls (planning the searches and
        // making the prediction); running the searches happens locally.
        let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;

        format!(
            "## Excerpts\n\n\
             {}\n\n\
             ## Prediction\n\n\
             {}\n\n\
             ## Time\n\n\
             Planning searches: {}ms\n\
             Running searches: {}ms\n\
             Making prediction: {}ms\n\n\
             -------------------\n\n\
             Total: {}ms\n\
             Inference: {}ms ({:.2}%)\n",
            self.excerpts_text,
            self.diff,
            self.planning_search_time.unwrap_or_default().as_millis(),
            self.running_search_time.unwrap_or_default().as_millis(),
            self.prediction_time.as_millis(),
            self.total_time.as_millis(),
            inference_time.as_millis(),
            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
        )
    }
}