predict.rs

use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaVersion;

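// Process-wide Anthropic client, created lazily on first use and shared across examples.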
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();

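/// Runs edit prediction for a single example. Teacher providers are routed
/// directly to Anthropic; all other providers go through the shared
/// `EditPredictionStore`, optionally repeating the request `args.repetitions`
/// times and recording each run's prompt, raw output, and parsed patch.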
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

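    // Teacher providers bypass the EditPredictionStore: format the prompt for the
    // teacher model and send it to Anthropic directly.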
    if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
        provider
    {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_anthropic(example, repetition_count, version, batched).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

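    // Zeta predictions hit the cloud endpoint, so make sure the client is signed in.
    // The `OnceLock<Shared<Task<()>>>` ensures sign-in happens at most once per process.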
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

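    // Map the CLI provider onto the store's model enum. Teacher variants returned
    // above, and the remaining variants must not reach this point (hence the
    // `unreachable!()`).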
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(version) => {
                edit_prediction::EditPredictionModel::Zeta2 { version }
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::Repair => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

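    // Listen to the store's debug events on a background task so each run's raw
    // prompt and model output are written to disk and copied into the example.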
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

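    // Request one prediction per repetition. With more than one repetition each run
    // gets its own numbered subdirectory, and LATEST_EXAMPLE_RUN_DIR always points
    // at the most recent one.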
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

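    // The debug stream is closed and its task has finished, so the background clones
    // of `updated_example` have been dropped and the Arc/Mutex can be unwrapped.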
    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

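/// Sends the formatted teacher prompt to Anthropic and records the response as a
/// prediction. In batched mode the request may only be stashed for later batch
/// processing, in which case no prediction is recorded yet.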
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    version: ZetaVersion,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(&example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch: Some(actual_patch),
        actual_output,
        error: None,
        provider: if batched {
            PredictionProvider::Teacher(version)
        } else {
            PredictionProvider::TeacherNonBatching(version)
        },
    };

    example.predictions.push(prediction);
    Ok(())
}

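/// Flushes any stashed batched teacher requests through Anthropic's batch API.
/// A no-op for every other provider.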
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(..)) => {
            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client")
            });
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}