predict.rs

use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

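/// Process-wide Anthropic client, initialized lazily by the first caller and
/// reused for all subsequent teacher predictions.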
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();

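/// Runs edit prediction for a single example with the given provider.
///
/// Teacher providers are routed through the Anthropic API using the formatted
/// teacher prompt; the remaining providers go through the shared
/// `EditPredictionStore`, repeating the prediction `repetition_count` times
/// and recording each result on the example.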
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let provider = provider.context("provider is required")?;

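    // Skip the example if it already has a prediction from this provider;
    // predictions left over from a different provider are discarded.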
    if let Some(existing_prediction) = example.predictions.first() {
        if existing_prediction.provider == provider {
            return Ok(());
        } else {
            example.predictions.clear();
        }
    }

    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;

    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);

        run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), cx.clone()).await?;

    let step_progress = Progress::global().start(Step::Predict, &example.spec.name);

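    // Zeta providers talk to the cloud service, so sign in once per process;
    // concurrent examples share the same sign-in task via `Shared`.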
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    step_progress.set_substatus("configuring model");
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

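    // Listen for debug events on a background task: prompts and responses are
    // written into the current run directory, and model output is copied onto
    // the prediction for the run that is currently in flight.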
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if provider == PredictionProvider::Zeta2 {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    format: PromptFormat::Zeta2,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // Stop listening once the final repetition has finished.
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

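    // Run the prediction once per repetition; with more than one repetition,
    // each run gets its own zero-padded subdirectory.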
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
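        // Repoint the "latest run" symlink at this run's directory, using the
        // platform-specific symlink API.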
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

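        // Render the predicted edits as a unified diff; an empty patch means
        // the provider returned no prediction for this run.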
        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction
                    .edit_preview
                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

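    // Detach the project from the store so the debug stream closes, then wait
    // for the listener task to finish flushing its output.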
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

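    // The debug task has dropped its clone by now, so take the accumulated
    // example back out of the Arc<Mutex<..>> and write it through.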
    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

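/// Requests a teacher prediction from Anthropic using the example's formatted
/// prompt. In batched mode the request may only be enqueued, in which case no
/// prediction is recorded until the batch is synced via `sync_batches`.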
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
    Ok(())
}

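/// Synchronizes the Anthropic batch queue for the Teacher provider; other
/// providers do not batch, so this is a no-op for them.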
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher => {
            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client")
            });
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}