1use crate::example::{ActualExcerpt, NamedExample};
2use crate::headless::ZetaCliAppState;
3use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
4use crate::{
5 CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
6};
7use ::serde::Serialize;
8use anyhow::{Context, Result, anyhow};
9use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
10use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
11use futures::StreamExt as _;
12use gpui::{AppContext, AsyncApp, Entity};
13use project::Project;
14use project::buffer_store::BufferStoreEvent;
15use serde::Deserialize;
16use std::fs;
17use std::io::{IsTerminal, Write};
18use std::path::PathBuf;
19use std::sync::Arc;
20use std::sync::Mutex;
21use std::time::{Duration, Instant};
22
23pub async fn run_predict(
24 args: PredictArguments,
25 app_state: &Arc<ZetaCliAppState>,
26 cx: &mut AsyncApp,
27) {
28 let example = NamedExample::load(args.example_path).unwrap();
29 let project = example.setup_project(app_state, cx).await.unwrap();
30 let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
31 let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
32 let result = perform_predict(example, project, store, None, args.options, cx)
33 .await
34 .unwrap();
35 result.write(args.format, std::io::stdout()).unwrap();
36
37 print_run_data_dir(true, std::io::stdout().is_terminal());
38}
39
40pub fn setup_store(
41 provider: PredictionProvider,
42 project: &Entity<Project>,
43 app_state: &Arc<ZetaCliAppState>,
44 cx: &mut AsyncApp,
45) -> Result<Entity<EditPredictionStore>> {
46 let store = cx.new(|cx| {
47 edit_prediction::EditPredictionStore::new(
48 app_state.client.clone(),
49 app_state.user_store.clone(),
50 cx,
51 )
52 })?;
53
54 store.update(cx, |store, _cx| {
55 let model = match provider {
56 PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
57 PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
58 PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
59 };
60 store.set_edit_prediction_model(model);
61 })?;
62
63 let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
64
65 cx.subscribe(&buffer_store, {
66 let project = project.clone();
67 let store = store.clone();
68 move |_, event, cx| match event {
69 BufferStoreEvent::BufferAdded(buffer) => {
70 store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
71 }
72 _ => {}
73 }
74 })?
75 .detach();
76
77 anyhow::Ok(store)
78}
79
80pub async fn perform_predict(
81 example: NamedExample,
82 project: Entity<Project>,
83 store: Entity<EditPredictionStore>,
84 repetition_ix: Option<u16>,
85 options: PredictionOptions,
86 cx: &mut AsyncApp,
87) -> Result<PredictionDetails> {
88 let mut cache_mode = options.cache;
89 if repetition_ix.is_some() {
90 if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
91 panic!("Repetitions are not supported in Auto cache mode");
92 } else {
93 cache_mode = CacheMode::Skip;
94 }
95 } else if cache_mode == CacheMode::Auto {
96 cache_mode = CacheMode::Requests;
97 }
98
99 let mut example_run_dir = RUN_DIR.join(&example.file_name());
100 if let Some(repetition_ix) = repetition_ix {
101 example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
102 }
103 fs::create_dir_all(&example_run_dir)?;
104 if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
105 fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
106 }
107
108 #[cfg(unix)]
109 std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
110 .context("creating latest link")?;
111
112 #[cfg(windows)]
113 std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
114 .context("creating latest link")?;
115
116 store.update(cx, |store, _cx| {
117 store.with_eval_cache(Arc::new(RunCache {
118 example_run_dir: example_run_dir.clone(),
119 cache_mode,
120 }));
121 })?;
122
123 let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
124
125 let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
126
127 let prompt_format = options.zeta2.prompt_format;
128
129 store.update(cx, |store, _cx| {
130 let mut options = store.options().clone();
131 options.prompt_format = prompt_format.into();
132 store.set_options(options);
133 })?;
134
135 let mut debug_task = gpui::Task::ready(Ok(()));
136
137 if options.provider == crate::PredictionProvider::Zeta2 {
138 let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
139
140 debug_task = cx.background_spawn({
141 let result = result.clone();
142 async move {
143 let mut start_time = None;
144 let mut retrieval_finished_at = None;
145 while let Some(event) = debug_rx.next().await {
146 match event {
147 edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
148 start_time = Some(info.timestamp);
149 fs::write(
150 example_run_dir.join("search_prompt.md"),
151 &info.search_prompt,
152 )?;
153 }
154 edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
155 retrieval_finished_at = Some(info.timestamp);
156 for (key, value) in &info.metadata {
157 if *key == "search_queries" {
158 fs::write(
159 example_run_dir.join("search_queries.json"),
160 value.as_bytes(),
161 )?;
162 }
163 }
164 }
165 edit_prediction::DebugEvent::EditPredictionRequested(request) => {
166 let prediction_started_at = Instant::now();
167 start_time.get_or_insert(prediction_started_at);
168 let prompt = request.local_prompt.unwrap_or_default();
169 fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
170
171 {
172 let mut result = result.lock().unwrap();
173 result.prompt_len = prompt.chars().count();
174
175 for included_file in request.inputs.included_files {
176 let insertions =
177 vec![(request.inputs.cursor_point, CURSOR_MARKER)];
178 result.excerpts.extend(included_file.excerpts.iter().map(
179 |excerpt| ActualExcerpt {
180 path: included_file.path.components().skip(1).collect(),
181 text: String::from(excerpt.text.as_ref()),
182 },
183 ));
184 write_codeblock(
185 &included_file.path,
186 included_file.excerpts.iter(),
187 if included_file.path == request.inputs.cursor_path {
188 &insertions
189 } else {
190 &[]
191 },
192 included_file.max_row,
193 false,
194 &mut result.excerpts_text,
195 );
196 }
197 }
198
199 let response =
200 request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
201 let response = edit_prediction::zeta2::text_from_response(response)
202 .unwrap_or_default();
203 let prediction_finished_at = Instant::now();
204 fs::write(example_run_dir.join("prediction_response.md"), &response)?;
205
206 let mut result = result.lock().unwrap();
207 result.generated_len = response.chars().count();
208 result.retrieval_time =
209 retrieval_finished_at.unwrap() - start_time.unwrap();
210 result.prediction_time = prediction_finished_at - prediction_started_at;
211 result.total_time = prediction_finished_at - start_time.unwrap();
212
213 break;
214 }
215 }
216 }
217 anyhow::Ok(())
218 }
219 });
220
221 store.update(cx, |store, cx| {
222 store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
223 })?;
224 }
225
226 let prediction = store
227 .update(cx, |store, cx| {
228 store.request_prediction(
229 &project,
230 &cursor_buffer,
231 cursor_anchor,
232 cloud_llm_client::PredictEditsRequestTrigger::Cli,
233 cx,
234 )
235 })?
236 .await?;
237
238 debug_task.await?;
239
240 let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
241
242 result.diff = prediction
243 .and_then(|prediction| {
244 let prediction = prediction.prediction.ok()?;
245 prediction.edit_preview.as_unified_diff(&prediction.edits)
246 })
247 .unwrap_or_default();
248
249 anyhow::Ok(result)
250}
251
252struct RunCache {
253 cache_mode: CacheMode,
254 example_run_dir: PathBuf,
255}
256
257impl RunCache {
258 fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
259 CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
260 }
261
262 fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
263 CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
264 }
265
266 fn link_to_run(&self, key: &EvalCacheKey) {
267 let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
268 fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
269
270 let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
271 fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
272 }
273}
274
impl EvalCache for RunCache {
    /// Looks up a cached output for `key`, honoring the configured cache
    /// mode. Returns `None` when the entry is absent or the mode says this
    /// kind of entry must be recomputed.
    ///
    /// Panics in `Force` mode when the entry is missing (by design: `Force`
    /// requires a fully warmed cache), and on unreadable cache files.
    fn read(&self, key: EvalCacheKey) -> Option<String> {
        let path = RunCache::output_cache_path(&key);

        if path.exists() {
            // Search results and LLM responses are gated by separate
            // cache-mode switches.
            let use_cache = match key.0 {
                EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
                EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
                    self.cache_mode.use_cached_llm_responses()
                }
            };
            if use_cache {
                log::info!("Using cache entry: {}", path.display());
                // Record in the run dir that this run consumed the entry.
                self.link_to_run(&key);
                Some(fs::read_to_string(path).unwrap())
            } else {
                log::trace!("Skipping cached entry: {}", path.display());
                None
            }
        } else if matches!(self.cache_mode, CacheMode::Force) {
            panic!(
                "No cached entry found for {:?}. Run without `--cache force` at least once.",
                key.0
            );
        } else {
            None
        }
    }

    /// Persists an input/output pair under `CACHE_DIR` and hard-links both
    /// files into the current run's directory.
    ///
    /// Panics on any filesystem error; a broken cache dir should abort the run.
    fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
        fs::create_dir_all(&*CACHE_DIR).unwrap();

        let input_path = RunCache::input_cache_path(&key);
        fs::write(&input_path, input).unwrap();

        let output_path = RunCache::output_cache_path(&key);
        log::trace!("Writing cache entry: {}", output_path.display());
        fs::write(&output_path, output).unwrap();

        self.link_to_run(&key);
    }
}
317
/// Everything recorded about a single prediction run, serializable for
/// output as markdown, JSON, or a bare diff.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
    /// Unified diff of the predicted edits; empty when there was no prediction.
    pub diff: String,
    /// Context excerpts that were included in the prediction prompt.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    /// Time spent on context retrieval (retrieval start to retrieval finish).
    pub retrieval_time: Duration,
    /// Time spent waiting for the prediction response itself.
    pub prediction_time: Duration,
    /// Wall-clock time from overall start to prediction finish.
    pub total_time: Duration,
    /// Directory where this run's artifacts were written.
    pub run_example_dir: PathBuf,
    /// Prediction prompt length, in characters.
    pub prompt_len: usize,
    /// Generated response length, in characters.
    pub generated_len: usize,
}
330
331impl PredictionDetails {
332 pub fn new(run_example_dir: PathBuf) -> Self {
333 Self {
334 diff: Default::default(),
335 excerpts: Default::default(),
336 excerpts_text: Default::default(),
337 retrieval_time: Default::default(),
338 prediction_time: Default::default(),
339 total_time: Default::default(),
340 run_example_dir,
341 prompt_len: 0,
342 generated_len: 0,
343 }
344 }
345
346 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
347 let formatted = match format {
348 PredictionsOutputFormat::Md => self.to_markdown(),
349 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
350 PredictionsOutputFormat::Diff => self.diff.clone(),
351 };
352
353 Ok(out.write_all(formatted.as_bytes())?)
354 }
355
356 pub fn to_markdown(&self) -> String {
357 format!(
358 "## Excerpts\n\n\
359 {}\n\n\
360 ## Prediction\n\n\
361 {}\n\n\
362 ## Time\n\n\
363 Retrieval: {}ms\n\
364 Prediction: {}ms\n\n\
365 Total: {}ms\n",
366 self.excerpts_text,
367 self.diff,
368 self.retrieval_time.as_millis(),
369 self.prediction_time.as_millis(),
370 self.total_time.as_millis(),
371 )
372 }
373}