1use crate::example::{ActualExcerpt, NamedExample};
2use crate::headless::ZetaCliAppState;
3use crate::paths::LOGS_DIR;
4use ::serde::Serialize;
5use anyhow::{Context as _, Result, anyhow};
6use clap::Args;
7use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
8use futures::StreamExt as _;
9use gpui::AsyncApp;
10use project::Project;
11use serde::Deserialize;
12use std::cell::Cell;
13use std::fs;
14use std::io::Write;
15use std::path::PathBuf;
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18
// Command-line arguments for the zeta2 `predict` subcommand.
// NOTE: plain `//` comments are used deliberately — `///` doc comments on a
// clap-derive struct would become `--help` text and change CLI output.
#[derive(Debug, Args)]
pub struct PredictArguments {
    // Path to the named example file describing the worktree, edit history,
    // and cursor position to predict from.
    example_path: PathBuf,
    // Output format for the prediction result; defaults to markdown.
    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
    format: PredictionsOutputFormat,
}
25
// How `PredictionDetails::write` serializes its output.
// NOTE: `//` comments (not `///`) so clap's derived help text is unchanged.
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
    // Pretty-printed JSON of the full PredictionDetails struct.
    Json,
    // Human-readable markdown report (excerpts, diff, timing breakdown).
    Md,
    // Just the predicted unified diff.
    Diff,
}
32pub async fn run_zeta2_predict(
33 args: PredictArguments,
34 app_state: &Arc<ZetaCliAppState>,
35 cx: &mut AsyncApp,
36) {
37 let example = NamedExample::load(args.example_path).unwrap();
38 let result = zeta2_predict(example, &app_state, cx).await.unwrap();
39 result.write(args.format, std::io::stdout()).unwrap();
40}
41
// Per-thread flag recording whether the client sign-in has already been
// performed, so repeated predictions on the same thread (e.g. across many
// examples) only authenticate once.
thread_local! {
    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
}
45
/// Runs a single zeta2 edit prediction against `example`, capturing the
/// context excerpts sent to the model, the resulting diff, and per-stage
/// timing information.
///
/// Side effects: writes debug artifacts into `LOGS_DIR` as the run
/// progresses (`search_prompt.md`, `search_queries.json`,
/// `prediction_prompt.md`, `prediction_response.md`), and signs the client
/// in on first use per thread.
///
/// # Errors
///
/// Returns an error if worktree setup, sign-in, project/worktree creation,
/// edit application, context retrieval, or the prediction request fails.
pub async fn zeta2_predict(
    example: NamedExample,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
    fs::create_dir_all(&*LOGS_DIR)?;
    let worktree_path = example.setup_worktree().await?;

    // Sign in at most once per thread (see AUTHENTICATED above).
    // NOTE(review): the flag is set *before* the sign-in await completes, so
    // if sign-in fails here it will not be retried on this thread — confirm
    // that is intended.
    if !AUTHENTICATED.get() {
        AUTHENTICATED.set(true);

        app_state
            .client
            .sign_in_with_optional_connect(true, cx)
            .await?;
    }

    // Build a local project over the example's worktree.
    let project = cx.update(|cx| {
        Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            app_state.fs.clone(),
            None,
            cx,
        )
    })?;

    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;

    // Create the worktree and wait for the initial file scan to finish so all
    // files are visible to the project before edits are applied.
    let worktree = project
        .update(cx, |project, cx| {
            project.create_worktree(&worktree_path, true, cx)
        })?
        .await?;
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;

    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;

    // Register every buffer added to the store with zeta, so the buffers
    // opened while applying the edit history are tracked for prediction.
    cx.subscribe(&buffer_store, {
        let project = project.clone();
        move |_, event, cx| match event {
            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
                zeta2::Zeta::try_global(cx)
                    .unwrap()
                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
            }
            _ => {}
        }
    })?
    .detach();

    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;

    // Subscribe to zeta's debug event stream *before* starting context
    // retrieval so no events are missed.
    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;

    let refresh_task = zeta.update(cx, |zeta, cx| {
        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
    })?;

    // Timestamps for each pipeline stage, filled in as debug events arrive;
    // used at the end to compute the per-stage durations.
    let mut context_retrieval_started_at = None;
    let mut context_retrieval_finished_at = None;
    let mut search_queries_generated_at = None;
    let mut search_queries_executed_at = None;
    let mut prediction_started_at = None;
    let mut prediction_finished_at = None;
    let mut excerpts_text = String::new();
    let mut prediction_task = None;
    let mut result = PredictionDetails::default();
    // Drive the run off the debug event stream: record timestamps, dump
    // artifacts, kick off the prediction once context retrieval finishes, and
    // break once the model response arrives.
    while let Some(event) = debug_rx.next().await {
        match event {
            zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                context_retrieval_started_at = Some(info.timestamp);
                fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
            }
            zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                search_queries_generated_at = Some(info.timestamp);
                fs::write(
                    LOGS_DIR.join("search_queries.json"),
                    serde_json::to_string_pretty(&info.regex_by_glob).unwrap(),
                )?;
            }
            zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
                search_queries_executed_at = Some(info.timestamp);
            }
            zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
                context_retrieval_finished_at = Some(info.timestamp);

                // Context is ready — request the actual edit prediction.
                prediction_task = Some(zeta.update(cx, |zeta, cx| {
                    zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
                })?);
            }
            zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
                prediction_started_at = Some(Instant::now());
                fs::write(
                    LOGS_DIR.join("prediction_prompt.md"),
                    &request.local_prompt.unwrap_or_default(),
                )?;

                // Collect the excerpts included in the request, both as
                // structured data (result.excerpts) and as rendered
                // codeblocks (excerpts_text), marking the cursor position in
                // the excerpt file.
                for included_file in request.request.included_files {
                    let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
                    result
                        .excerpts
                        .extend(included_file.excerpts.iter().map(|excerpt| ActualExcerpt {
                            // Skip the leading path component (presumably the
                            // worktree root — see the excerpts_text TODO on
                            // PredictionDetails).
                            path: included_file.path.components().skip(1).collect(),
                            text: String::from(excerpt.text.as_ref()),
                        }));
                    write_codeblock(
                        &included_file.path,
                        included_file.excerpts.iter(),
                        if included_file.path == request.request.excerpt_path {
                            &insertions
                        } else {
                            &[]
                        },
                        included_file.max_row,
                        false,
                        &mut excerpts_text,
                    );
                }

                // Wait for the model response, then stop listening.
                let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                let response = zeta2::text_from_response(response).unwrap_or_default();
                prediction_finished_at = Some(Instant::now());
                fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;

                break;
            }
        }
    }

    refresh_task.await.context("context retrieval failed")?;
    // prediction_task is always set by the time the loop breaks (the
    // EditPredictionRequested event follows ContextRetrievalFinished).
    let prediction = prediction_task.unwrap().await?;

    // Apply the predicted edits to the buffer and diff against the snapshot
    // taken before them; empty diff if there was no prediction.
    result.diff = prediction
        .map(|prediction| {
            let old_text = prediction.snapshot.text();
            let new_text = prediction.buffer.update(cx, |buffer, cx| {
                buffer.edit(prediction.edits.iter().cloned(), None, cx);
                buffer.text()
            })?;
            anyhow::Ok(language::unified_diff(&old_text, &new_text))
        })
        .transpose()?
        .unwrap_or_default();
    result.excerpts_text = excerpts_text;

    // Stage durations. These unwraps assume every debug event above was
    // observed; if the provider skips a stage this will panic.
    result.planning_search_time =
        search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
    result.running_search_time =
        search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap();
    result.filtering_search_time =
        context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
    result.prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
    result.total_time = prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap();

    anyhow::Ok(result)
}
210
/// The collected outcome of a single zeta2 prediction run: the predicted
/// diff, the context excerpts that were sent to the model, and a timing
/// breakdown of each pipeline stage.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
    /// Unified diff of the predicted edits; empty if nothing was predicted.
    pub diff: String,
    /// Structured excerpts included in the prediction request.
    pub excerpts: Vec<ActualExcerpt>,
    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
    /// Time spent generating search queries.
    pub planning_search_time: Duration,
    /// Time spent filtering search results into the final context.
    pub filtering_search_time: Duration,
    /// Time spent running the generated searches.
    pub running_search_time: Duration,
    /// Time spent waiting for the model's edit prediction.
    pub prediction_time: Duration,
    /// Wall-clock time from context retrieval start to prediction response.
    pub total_time: Duration,
}
222
223impl PredictionDetails {
224 pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
225 let formatted = match format {
226 PredictionsOutputFormat::Md => self.to_markdown(),
227 PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
228 PredictionsOutputFormat::Diff => self.diff.clone(),
229 };
230
231 Ok(out.write_all(formatted.as_bytes())?)
232 }
233
234 pub fn to_markdown(&self) -> String {
235 let inference_time =
236 self.planning_search_time + self.filtering_search_time + self.prediction_time;
237
238 format!(
239 "## Excerpts\n\n\
240 {}\n\n\
241 ## Prediction\n\n\
242 {}\n\n\
243 ## Time\n\n\
244 Planning searches: {}ms\n\
245 Running searches: {}ms\n\
246 Filtering context results: {}ms\n\
247 Making Prediction: {}ms\n\n\
248 -------------------\n\n\
249 Total: {}ms\n\
250 Inference: {}ms ({:.2}%)\n",
251 self.excerpts_text,
252 self.diff,
253 self.planning_search_time.as_millis(),
254 self.running_search_time.as_millis(),
255 self.filtering_search_time.as_millis(),
256 self.prediction_time.as_millis(),
257 self.total_time.as_millis(),
258 inference_time.as_millis(),
259 (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
260 )
261 }
262}