predict.rs

use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaVersion;

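/// Shared Anthropic client, created lazily and reused for all teacher predictions.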
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();

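/// Runs edit prediction for `example` using the provider selected in `args`,
/// skipping the example when a prediction from the same provider is already recorded.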
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let provider = args.provider;
    let repetition_count = args.repetitions;

    if let Some(existing_prediction) = example.predictions.first() {
        if existing_prediction.provider == provider {
            return Ok(());
        } else {
            example.predictions.clear();
        }
    }

    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
        args.provider
    {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_anthropic(example, repetition_count, version, batched).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

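    // Zeta providers require a signed-in client; authenticate once and share the
    // resulting task across examples.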
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    step_progress.set_substatus("configuring model");
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(version) => {
                edit_prediction::EditPredictionModel::Zeta2 { version }
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            // Teacher providers already returned early above.
            PredictionProvider::Teacher(..) | PredictionProvider::TeacherNonBatching(..) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

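    // Capture debug events in the background so each run's prompt and raw model
    // response can be written to disk and recorded on the example.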
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // Stop listening once the final repetition has finished.
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

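    // Run the requested number of repetitions, each writing its artifacts into a
    // numbered subdirectory when more than one repetition is requested.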
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction
                    .edit_preview
                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

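/// Sends the already-formatted teacher prompt to Anthropic and records the response
/// as a prediction. In batched mode the request may only be stashed for later batch
/// processing, in which case nothing is recorded yet.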
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    version: ZetaVersion,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: if batched {
            PredictionProvider::Teacher(version)
        } else {
            PredictionProvider::TeacherNonBatching(version)
        },
    };

    example.predictions.push(prediction);
    Ok(())
}

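/// Flushes any pending Anthropic batch requests when the batching teacher provider is in use.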
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher(..) => {
            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client")
            });
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}