predict.rs

  1use crate::PromptFormat;
  2use crate::example::{ActualExcerpt, NamedExample};
  3use crate::headless::ZetaCliAppState;
  4use crate::paths::LOGS_DIR;
  5use ::serde::Serialize;
  6use anyhow::{Result, anyhow};
  7use clap::Args;
  8// use cloud_llm_client::predict_edits_v3::PromptFormat;
  9use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 10use futures::StreamExt as _;
 11use gpui::{AppContext, AsyncApp};
 12use project::Project;
 13use serde::Deserialize;
 14use std::cell::Cell;
 15use std::fs;
 16use std::io::Write;
 17use std::path::PathBuf;
 18use std::sync::Arc;
 19use std::sync::Mutex;
 20use std::time::{Duration, Instant};
 21
 22#[derive(Debug, Args)]
 23pub struct PredictArguments {
 24    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 25    prompt_format: PromptFormat,
 26    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
 27    format: PredictionsOutputFormat,
 28    example_path: PathBuf,
 29}
 30
 31#[derive(clap::ValueEnum, Debug, Clone)]
 32pub enum PredictionsOutputFormat {
 33    Json,
 34    Md,
 35    Diff,
 36}
 37pub async fn run_zeta2_predict(
 38    args: PredictArguments,
 39    app_state: &Arc<ZetaCliAppState>,
 40    cx: &mut AsyncApp,
 41) {
 42    let example = NamedExample::load(args.example_path).unwrap();
 43    let result = zeta2_predict(example, args.prompt_format, &app_state, cx)
 44        .await
 45        .unwrap();
 46    result.write(args.format, std::io::stdout()).unwrap();
 47}
 48
 49thread_local! {
 50    static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
 51}
 52
 53pub async fn zeta2_predict(
 54    example: NamedExample,
 55    prompt_format: PromptFormat,
 56    app_state: &Arc<ZetaCliAppState>,
 57    cx: &mut AsyncApp,
 58) -> Result<PredictionDetails> {
 59    fs::create_dir_all(&*LOGS_DIR)?;
 60    let worktree_path = example.setup_worktree().await?;
 61
 62    if !AUTHENTICATED.get() {
 63        AUTHENTICATED.set(true);
 64
 65        app_state
 66            .client
 67            .sign_in_with_optional_connect(true, cx)
 68            .await?;
 69    }
 70
 71    let project = cx.update(|cx| {
 72        Project::local(
 73            app_state.client.clone(),
 74            app_state.node_runtime.clone(),
 75            app_state.user_store.clone(),
 76            app_state.languages.clone(),
 77            app_state.fs.clone(),
 78            None,
 79            cx,
 80        )
 81    })?;
 82
 83    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
 84
 85    let worktree = project
 86        .update(cx, |project, cx| {
 87            project.create_worktree(&worktree_path, true, cx)
 88        })?
 89        .await?;
 90    worktree
 91        .read_with(cx, |worktree, _cx| {
 92            worktree.as_local().unwrap().scan_complete()
 93        })?
 94        .await;
 95
 96    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
 97
 98    cx.subscribe(&buffer_store, {
 99        let project = project.clone();
100        move |_, event, cx| match event {
101            project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
102                zeta2::Zeta::try_global(cx)
103                    .unwrap()
104                    .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
105            }
106            _ => {}
107        }
108    })?
109    .detach();
110
111    let _edited_buffers = example.apply_edit_history(&project, cx).await?;
112    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
113
114    let result = Arc::new(Mutex::new(PredictionDetails::default()));
115    let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
116
117    let debug_task = cx.background_spawn({
118        let result = result.clone();
119        async move {
120            let mut context_retrieval_started_at = None;
121            let mut context_retrieval_finished_at = None;
122            let mut search_queries_generated_at = None;
123            let mut search_queries_executed_at = None;
124            while let Some(event) = debug_rx.next().await {
125                match event {
126                    zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
127                        context_retrieval_started_at = Some(info.timestamp);
128                        fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
129                    }
130                    zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
131                        search_queries_generated_at = Some(info.timestamp);
132                        fs::write(
133                            LOGS_DIR.join("search_queries.json"),
134                            serde_json::to_string_pretty(&info.search_queries).unwrap(),
135                        )?;
136                    }
137                    zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
138                        search_queries_executed_at = Some(info.timestamp);
139                    }
140                    zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
141                        context_retrieval_finished_at = Some(info.timestamp);
142                    }
143                    zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
144                        let prediction_started_at = Instant::now();
145                        fs::write(
146                            LOGS_DIR.join("prediction_prompt.md"),
147                            &request.local_prompt.unwrap_or_default(),
148                        )?;
149
150                        {
151                            let mut result = result.lock().unwrap();
152
153                            for included_file in request.request.included_files {
154                                let insertions =
155                                    vec![(request.request.cursor_point, CURSOR_MARKER)];
156                                result.excerpts.extend(included_file.excerpts.iter().map(
157                                    |excerpt| ActualExcerpt {
158                                        path: included_file.path.components().skip(1).collect(),
159                                        text: String::from(excerpt.text.as_ref()),
160                                    },
161                                ));
162                                write_codeblock(
163                                    &included_file.path,
164                                    included_file.excerpts.iter(),
165                                    if included_file.path == request.request.excerpt_path {
166                                        &insertions
167                                    } else {
168                                        &[]
169                                    },
170                                    included_file.max_row,
171                                    false,
172                                    &mut result.excerpts_text,
173                                );
174                            }
175                        }
176
177                        let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
178                        let response = zeta2::text_from_response(response).unwrap_or_default();
179                        let prediction_finished_at = Instant::now();
180                        fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;
181
182                        let mut result = result.lock().unwrap();
183
184                        result.planning_search_time = search_queries_generated_at.unwrap()
185                            - context_retrieval_started_at.unwrap();
186                        result.running_search_time = search_queries_executed_at.unwrap()
187                            - search_queries_generated_at.unwrap();
188                        result.filtering_search_time = context_retrieval_finished_at.unwrap()
189                            - search_queries_executed_at.unwrap();
190                        result.prediction_time = prediction_finished_at - prediction_started_at;
191                        result.total_time =
192                            prediction_finished_at - context_retrieval_started_at.unwrap();
193
194                        break;
195                    }
196                }
197            }
198            anyhow::Ok(())
199        }
200    });
201
202    zeta.update(cx, |zeta, cx| {
203        let mut options = zeta.options().clone();
204        options.prompt_format = prompt_format.into();
205        zeta.set_options(options);
206        zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
207    })?
208    .await?;
209
210    let prediction = zeta
211        .update(cx, |zeta, cx| {
212            zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
213        })?
214        .await?;
215
216    debug_task.await?;
217
218    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
219    result.diff = prediction
220        .map(|prediction| {
221            let old_text = prediction.snapshot.text();
222            let new_text = prediction
223                .buffer
224                .update(cx, |buffer, cx| {
225                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
226                    buffer.text()
227                })
228                .unwrap();
229            language::unified_diff(&old_text, &new_text)
230        })
231        .unwrap_or_default();
232
233    anyhow::Ok(result)
234}
235
236#[derive(Clone, Debug, Default, Serialize, Deserialize)]
237pub struct PredictionDetails {
238    pub diff: String,
239    pub excerpts: Vec<ActualExcerpt>,
240    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
241    pub planning_search_time: Duration,
242    pub filtering_search_time: Duration,
243    pub running_search_time: Duration,
244    pub prediction_time: Duration,
245    pub total_time: Duration,
246}
247
248impl PredictionDetails {
249    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
250        let formatted = match format {
251            PredictionsOutputFormat::Md => self.to_markdown(),
252            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
253            PredictionsOutputFormat::Diff => self.diff.clone(),
254        };
255
256        Ok(out.write_all(formatted.as_bytes())?)
257    }
258
259    pub fn to_markdown(&self) -> String {
260        let inference_time =
261            self.planning_search_time + self.filtering_search_time + self.prediction_time;
262
263        format!(
264            "## Excerpts\n\n\
265            {}\n\n\
266            ## Prediction\n\n\
267            {}\n\n\
268            ## Time\n\n\
269            Planning searches: {}ms\n\
270            Running searches: {}ms\n\
271            Filtering context results: {}ms\n\
272            Making Prediction: {}ms\n\n\
273            -------------------\n\n\
274            Total: {}ms\n\
275            Inference: {}ms ({:.2}%)\n",
276            self.excerpts_text,
277            self.diff,
278            self.planning_search_time.as_millis(),
279            self.running_search_time.as_millis(),
280            self.filtering_search_time.as_millis(),
281            self.prediction_time.as_millis(),
282            self.total_time.as_millis(),
283            inference_time.as_millis(),
284            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
285        )
286    }
287}