1use crate::PromptFormat;
2use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
3use crate::headless::ZetaCliAppState;
4use crate::paths::{CACHE_DIR, LOGS_DIR};
5use ::serde::Serialize;
6use anyhow::{Result, anyhow};
7use clap::Args;
8use collections::HashMap;
9use gpui::http_client::Url;
10use language::{Anchor, Buffer, Point};
11// use cloud_llm_client::predict_edits_v3::PromptFormat;
12use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
13use futures::StreamExt as _;
14use gpui::{AppContext, AsyncApp, Entity};
15use project::Project;
16use serde::Deserialize;
17use std::cell::Cell;
18use std::fs;
19use std::io::Write;
20use std::ops::Range;
21use std::path::PathBuf;
22use std::sync::Arc;
23use std::sync::Mutex;
24use std::time::{Duration, Instant};
25use zeta2::LlmResponseCache;
26
27#[derive(Debug, Args)]
28pub struct PredictArguments {
29 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
30 prompt_format: PromptFormat,
31 #[arg(long)]
32 use_expected_context: bool,
33 #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
34 format: PredictionsOutputFormat,
35 example_path: PathBuf,
36 #[clap(long)]
37 skip_cache: bool,
38}
39
/// Output formats supported when printing prediction results.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    // Pretty-printed JSON of the full `PredictionDetails`.
    Json,
    // Human-readable markdown report (the default).
    Md,
    // Only the unified diff of the predicted edits.
    Diff,
}
46
47pub async fn run_zeta2_predict(
48 args: PredictArguments,
49 app_state: &Arc<ZetaCliAppState>,
50 cx: &mut AsyncApp,
51) {
52 let example = NamedExample::load(args.example_path).unwrap();
53 let result = zeta2_predict(
54 example,
55 args.skip_cache,
56 args.prompt_format,
57 args.use_expected_context,
58 &app_state,
59 cx,
60 )
61 .await
62 .unwrap();
63 result.write(args.format, std::io::stdout()).unwrap();
64}
65
// Set before the first sign-in in `zeta2_predict` so that subsequent calls on
// the same thread skip authentication.
thread_local! {
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}
69
/// Runs a single zeta2 edit prediction for `example` and returns the
/// predicted diff, the context excerpts sent to the model, and a timing
/// breakdown. Intermediate artifacts (search prompt, search queries,
/// prediction prompt/response) are written into `LOGS_DIR` for inspection.
///
/// When `use_expected_context` is true, the example's expected context is
/// handed to zeta directly instead of running search-based context
/// retrieval, and the search timings are left as `None`.
pub async fn zeta2_predict(
    example: NamedExample,
    skip_cache: bool,
    prompt_format: PromptFormat,
    use_expected_context: bool,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    fs::create_dir_all(&*LOGS_DIR)?;
    let worktree_path = example.setup_worktree().await?;

    // Sign in at most once per thread; later calls reuse the session.
    if !AUTHENTICATED.get() {
        AUTHENTICATED.set(true);

        app_state
            .client
            .sign_in_with_optional_connect(true, cx)
            .await?;
    }

    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    // Create the worktree and wait for the initial scan so that all example
    // paths are resolvable before the edit history is applied.
    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;

    zeta.update(cx, |zeta, _cx| {
        zeta.with_llm_response_cache(Arc::new(Cache { skip_cache }));
    })?;

    // Register each buffer with zeta as soon as it is added to the store.
    cx.subscribe(&buffer_store, {
        let project = project.clone();
        move |_, event, cx| match event {
            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
                zeta2::Zeta::try_global(cx)
                    .unwrap()
                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    .detach();

    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    let result = Arc::new(Mutex::new(PredictionDetails::default()));
    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

    // Drain zeta's debug event stream in the background: write prompts and
    // queries to LOGS_DIR, capture the excerpts included in the request, and
    // record per-phase timings. The task exits after the first prediction.
    let debug_task = cx.background_spawn({
        let result = result.clone();
        async move {
            let mut start_time = None;
            let mut search_queries_generated_at = None;
            let mut search_queries_executed_at = None;
            while let Some(event) = debug_rx.next().await {
                match event {
                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                        start_time = Some(info.timestamp);
                        fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                        search_queries_generated_at = Some(info.timestamp);
                        fs::write(
                            LOGS_DIR.join("search_queries.json"),
                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
                        )?;
                    }
                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                        search_queries_executed_at = Some(info.timestamp);
                    }
                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                        let prediction_started_at = Instant::now();
                        // If context retrieval never ran (expected-context
                        // mode), the prediction request marks the start.
                        start_time.get_or_insert(prediction_started_at);
                        fs::write(
                            LOGS_DIR.join("prediction_prompt.md"),
                            &request.local_prompt.unwrap_or_default(),
                        )?;

                        {
                            let mut result = result.lock().unwrap();

                            for included_file in request.request.included_files {
                                // The cursor marker is only inserted into the
                                // file containing the cursor excerpt (checked
                                // against `excerpt_path` below).
                                let insertions =
                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
                                result.excerpts.extend(included_file.excerpts.iter().map(
                                    |excerpt| ActualExcerpt {
                                        // Drop the first path component
                                        // (presumably the worktree root — see
                                        // the TODO on `excerpts_text`).
                                        path: included_file.path.components().skip(1).collect(),
                                        text: String::from(excerpt.text.as_ref()),
                                    },
                                ));
                                write_codeblock(
                                    &included_file.path,
                                    included_file.excerpts.iter(),
                                    if included_file.path == request.request.excerpt_path {
                                        &insertions
                                    } else {
                                        &[]
                                    },
                                    included_file.max_row,
                                    false,
                                    &mut result.excerpts_text,
                                );
                            }
                        }

                        // Wait for the model's response, then log it.
                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                        let response = zeta2::text_from_response(response).unwrap_or_default();
                        let prediction_finished_at = Instant::now();
                        fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;

                        let mut result = result.lock().unwrap();

                        // Search timings only exist when retrieval actually ran.
                        if !use_expected_context {
                            result.planning_search_time =
                                Some(search_queries_generated_at.unwrap() - start_time.unwrap());
                            result.running_search_time = Some(
                                search_queries_executed_at.unwrap()
                                    - search_queries_generated_at.unwrap(),
                            );
                        }
                        result.prediction_time = prediction_finished_at - prediction_started_at;
                        result.total_time = prediction_finished_at - start_time.unwrap();

                        break;
                    }
                }
            }
            anyhow::Ok(())
        }
    });

    zeta.update(cx, |zeta, _cx| {
        let mut options = zeta.options().clone();
        options.prompt_format = prompt_format.into();
        zeta.set_options(options);
    })?;

    if use_expected_context {
        // Resolve the example's expected excerpts into buffer anchor ranges
        // and hand them to zeta directly, bypassing search-based retrieval.
        let context_excerpts_tasks = example
            .example
            .expected_context
            .iter()
            .flat_map(|section| {
                section.alternatives[0].excerpts.iter().map(|excerpt| {
                    resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
                })
            })
            .collect::<Vec<_>>();
        let context_excerpts_vec = futures::future::try_join_all(context_excerpts_tasks).await?;

        // Group the resolved ranges by buffer.
        let mut context_excerpts = HashMap::default();
        for (buffer, mut excerpts) in context_excerpts_vec {
            context_excerpts
                .entry(buffer)
                .or_insert(Vec::new())
                .append(&mut excerpts);
        }

        zeta.update(cx, |zeta, _cx| {
            zeta.set_context(project.clone(), context_excerpts)
        })?;
    } else {
        // Run zeta's own search-based context retrieval from the cursor.
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
        })?
        .await?;
    }

    let prediction = zeta
        .update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
        })?
        .await?;

    // Wait until the debug listener has recorded timings and excerpts.
    debug_task.await?;

    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
    // Apply the predicted edits to the buffer and diff the result against the
    // pre-edit snapshot; no prediction yields an empty diff.
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction
                .buffer
                .update(cx, |buffer, cx| {
                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
                    buffer.text()
                })
                .unwrap();
            language::unified_diff(&old_text, &new_text)
        })
        .unwrap_or_default();

    anyhow::Ok(result)
}
286
287async fn resolve_context_entry(
288 project: Entity<Project>,
289 excerpt: ExpectedExcerpt,
290 mut cx: AsyncApp,
291) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
292 let buffer = project
293 .update(&mut cx, |project, cx| {
294 let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
295 project.open_buffer(project_path, cx)
296 })?
297 .await?;
298
299 let ranges = buffer.read_with(&mut cx, |buffer, _| {
300 let full_text = buffer.text();
301 let offset = full_text
302 .find(&excerpt.text)
303 .expect("Expected context not found");
304 let point = buffer.offset_to_point(offset);
305 excerpt
306 .required_lines
307 .iter()
308 .map(|line| {
309 let row = point.row + line.0;
310 let range = Point::new(row, 0)..Point::new(row + 1, 0);
311 buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
312 })
313 .collect()
314 })?;
315
316 Ok((buffer, ranges))
317}
318
/// On-disk cache for LLM responses, keyed by a hash of the request URL and body.
struct Cache {
    // When true, existing cache entries are ignored on read; new responses are
    // still written.
    skip_cache: bool,
}
322
323impl Cache {
324 fn path(key: u64) -> PathBuf {
325 CACHE_DIR.join(format!("{key:x}.json"))
326 }
327}
328
329impl LlmResponseCache for Cache {
330 fn get_key(&self, url: &Url, body: &str) -> u64 {
331 use collections::FxHasher;
332 use std::hash::{Hash, Hasher};
333
334 let mut hasher = FxHasher::default();
335 url.hash(&mut hasher);
336 body.hash(&mut hasher);
337 hasher.finish()
338 }
339
340 fn read_response(&self, key: u64) -> Option<String> {
341 let path = Cache::path(key);
342 if path.exists() {
343 if self.skip_cache {
344 log::info!("Skipping existing cached LLM response: {}", path.display());
345 None
346 } else {
347 log::info!("Using LLM response from cache: {}", path.display());
348 Some(fs::read_to_string(path).unwrap())
349 }
350 } else {
351 None
352 }
353 }
354
355 fn write_response(&self, key: u64, value: &str) {
356 fs::create_dir_all(&*CACHE_DIR).unwrap();
357
358 let path = Cache::path(key);
359 log::info!("Writing LLM response to cache: {}", path.display());
360 fs::write(path, value).unwrap();
361 }
362}
363
/// Outcome of a single zeta2 prediction run: the predicted diff, the context
/// excerpts that were sent to the model, and timing breakdowns.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    // Unified diff of the predicted edits; empty when no prediction was made.
    pub diff: String,
    // Context excerpts included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    // Rendered codeblock form of the excerpts, as sent in the prompt.
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    // `None` when the expected context was used (no search phase ran).
    pub planning_search_time: Option<Duration>,
    // `None` when the expected context was used (no search phase ran).
    pub running_search_time: Option<Duration>,
    pub prediction_time: Duration,
    pub total_time: Duration,
}
374
375impl PredictionDetails {
376 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
377 let formatted = match format {
378 PredictionsOutputFormat::Md => self.to_markdown(),
379 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
380 PredictionsOutputFormat::Diff => self.diff.clone(),
381 };
382
383 Ok(out.write_all(formatted.as_bytes())?)
384 }
385
386 pub fn to_markdown(&self) -> String {
387 let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
388
389 format!(
390 "## Excerpts\n\n\
391 {}\n\n\
392 ## Prediction\n\n\
393 {}\n\n\
394 ## Time\n\n\
395 Planning searches: {}ms\n\
396 Running searches: {}ms\n\
397 Making Prediction: {}ms\n\n\
398 -------------------\n\n\
399 Total: {}ms\n\
400 Inference: {}ms ({:.2}%)\n",
401 self.excerpts_text,
402 self.diff,
403 self.planning_search_time.unwrap_or_default().as_millis(),
404 self.running_search_time.unwrap_or_default().as_millis(),
405 self.prediction_time.as_millis(),
406 self.total_time.as_millis(),
407 inference_time.as_millis(),
408 (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
409 )
410 }
411}