predict.rs

use crate::{
    BatchProvider, FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    llm_client::{LlmClient, model_for_backend},
    load_project::run_load_project,
    openai_client::OpenAiClient,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Step},
    qa,
    repair::{build_repair_prompt, needs_repair, parse_repair_response},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

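// Process-wide LLM clients, created lazily and shared by every example so the
// same client (and its batch/cache state) is reused for the whole run.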
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();

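/// Generate predictions for `example` using the provider selected in `args`.
///
/// Existing predictions are reused when they already match the requested
/// provider. Teacher and repaired-teacher providers are served by the
/// prompt-based LLM clients; all other providers load the project and request
/// predictions from the shared `EditPredictionStore`, once per repetition.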
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

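    // Reuse an existing prediction unless a different provider was explicitly
    // requested, in which case the stale predictions are cleared.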
    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

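    // Teacher-style providers only need the formatted prompt, not a loaded
    // project, so they are handled before the project is set up.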
    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(example, backend, batched, repetition_count).await;
    }

    if let PredictionProvider::RepairedTeacher(backend) = provider {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        return predict_repaired_teacher(example, backend, repetition_count).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

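    // The Zeta providers need a signed-in client; authenticate once per
    // process and share the task so concurrent examples await the same sign-in.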
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

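    // Map the requested provider onto the store's edit prediction model; the
    // teacher and repair variants are expected never to reach this point.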
    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(version) => {
                edit_prediction::EditPredictionModel::Zeta2 { version }
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::RepairedTeacher(..)
            | PredictionProvider::Repair => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

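    // Listen for the store's debug events in the background: the prompt is
    // written to disk when a prediction starts, and the raw model output is
    // written and recorded on the example when it finishes. `current_run_ix`
    // tells the task which repetition the incoming events belong to.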
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

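    // Request one prediction per repetition. With multiple repetitions each
    // run writes into its own zero-padded subdirectory, and the
    // LATEST_EXAMPLE_RUN_DIR symlink is repointed at the run in progress.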
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

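        // Render the predicted edits as a unified diff against the buffer
        // snapshot; a missing or empty diff is reported as "no prediction".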
        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

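    // Tear down: drop the project from the store, wait for the debug task to
    // finish, then move the accumulated results back into `example`.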
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

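/// Dispatch a teacher prediction to the backend-specific client.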
async fn predict_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
) -> anyhow::Result<()> {
    match backend {
        TeacherBackend::Sonnet45 => {
            predict_anthropic(example, backend, batched, repetition_count).await
        }
        TeacherBackend::Gpt52 => predict_openai(example, backend, batched, repetition_count).await,
    }
}

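/// Run the teacher prompt against Anthropic, once per repetition.
///
/// When running in batch mode, a `None` response means the request was queued
/// rather than answered; the example is left without a prediction until the
/// batch has been synced.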
async fn predict_anthropic(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        let messages = vec![anthropic::Message {
            role: anthropic::Role::User,
            content: vec![anthropic::RequestContent::Text {
                text: prompt.input.clone(),
                cache_control: None,
            }],
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed)
            .await?
        else {
            // Request stashed for batched processing
            return Ok(());
        };

        let actual_output = response
            .content
            .into_iter()
            .filter_map(|content| match content {
                anthropic::ResponseContent::Text { text } => Some(text),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            error: None,
            provider: if batched {
                PredictionProvider::Teacher(backend)
            } else {
                PredictionProvider::TeacherNonBatching(backend)
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

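/// OpenAI counterpart of `predict_anthropic`: the same per-repetition flow
/// with OpenAI request/response types.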
async fn predict_openai(
    example: &mut Example,
    backend: TeacherBackend,
    batched: bool,
    repetition_count: usize,
) -> anyhow::Result<()> {
    let llm_model_name = backend.model_name();
    let max_tokens = 16384;
    let llm_client = OPENAI_CLIENT.get_or_init(|| {
        let client = if batched {
            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            OpenAiClient::plain()
        };
        client.expect("Failed to create OpenAI client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    for ix in 0..repetition_count {
        let messages = vec![open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(prompt.input.clone()),
        }];

        let seed = if repetition_count > 1 { Some(ix) } else { None };
        let Some(response) = llm_client
            .generate(llm_model_name, max_tokens, messages, seed)
            .await?
        else {
            // Request stashed for batched processing
            return Ok(());
        };

        let actual_output = response
            .choices
            .into_iter()
            .filter_map(|choice| match choice.message {
                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
                    open_ai::MessageContent::Plain(text) => text,
                    open_ai::MessageContent::Multipart(parts) => parts
                        .into_iter()
                        .filter_map(|p| match p {
                            open_ai::MessagePart::Text { text } => Some(text),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join(""),
                }),
                _ => None,
            })
            .collect::<Vec<String>>()
            .join("\n");

        let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

        let prediction = ExamplePrediction {
            actual_patch: Some(actual_patch),
            actual_output,
            error: None,
            provider: if batched {
                PredictionProvider::Teacher(backend)
            } else {
                PredictionProvider::TeacherNonBatching(backend)
            },
        };

        example.predictions.push(prediction);
    }
    Ok(())
}

/// Default confidence threshold for repair
const DEFAULT_REPAIR_CONFIDENCE_THRESHOLD: u8 = 3;

/// Predict using the teacher model, then run QA evaluation, and optionally
/// repair if QA indicates issues (`reverts_edits` is true or confidence is low).
///
/// This is a non-batched flow that processes each step synchronously.
async fn predict_repaired_teacher(
    example: &mut Example,
    backend: TeacherBackend,
    repetition_count: usize,
) -> anyhow::Result<()> {
    // Step 1: Run teacher prediction (non-batched for immediate results)
    predict_teacher(example, backend, false, repetition_count).await?;

    // Only proceed with QA/repair for the first prediction
    let Some(prediction) = example.predictions.first() else {
        return Ok(());
    };

    // Skip QA if no actual patch was generated
    if prediction.actual_patch.is_none() {
        return Ok(());
    }

    // Step 2: Run QA evaluation
    let batch_provider = match backend {
        TeacherBackend::Sonnet45 => BatchProvider::Anthropic,
        TeacherBackend::Gpt52 => BatchProvider::Openai,
    };
    let qa_client = LlmClient::new(batch_provider, false)?;
    let qa_model = model_for_backend(batch_provider);

    let qa_result = if let Some(qa_prompt) = qa::build_prompt(example) {
        match qa_client.generate(qa_model, 1024, &qa_prompt).await? {
            Some(response_text) => Some(qa::parse_response(&response_text)),
            None => None,
        }
    } else {
        None
    };

    // Store QA result
    example.qa = vec![qa_result.clone()];

    // Step 3: Check if repair is needed and run repair if so
    if needs_repair(example, DEFAULT_REPAIR_CONFIDENCE_THRESHOLD) {
        let repair_client = LlmClient::new(batch_provider, false)?;

        if let Some(repair_prompt) = build_repair_prompt(example) {
            if let Some(response_text) = repair_client
                .generate(qa_model, 16384, &repair_prompt)
                .await?
            {
                match parse_repair_response(example, &response_text) {
                    Ok(mut repaired_prediction) => {
                        // Mark the prediction as coming from repaired-teacher
                        repaired_prediction.provider = PredictionProvider::RepairedTeacher(backend);
                        example.predictions.push(repaired_prediction);
                    }
                    Err(e) => {
                        // Add error prediction if parsing failed
                        example.predictions.push(ExamplePrediction {
                            actual_patch: None,
                            actual_output: response_text,
                            error: Some(format!("Failed to parse repair response: {}", e)),
                            provider: PredictionProvider::RepairedTeacher(backend),
                        });
                    }
                }
            }
        }
    }

    Ok(())
}

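/// Sync pending batched requests for the teacher provider's backing client
/// (Anthropic or OpenAI). Any other provider is a no-op.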
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
    match provider {
        Some(PredictionProvider::Teacher(backend)) => match backend {
            TeacherBackend::Sonnet45 => {
                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create Anthropic client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync Anthropic batches")?;
            }
            TeacherBackend::Gpt52 => {
                let llm_client = OPENAI_CLIENT.get_or_init(|| {
                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
                        .expect("Failed to create OpenAI client")
                });
                llm_client
                    .sync_batches()
                    .await
                    .context("Failed to sync OpenAI batches")?;
            }
        },
        _ => (),
    };
    Ok(())
}