predict.rs

use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

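/// Runs edit prediction for `example` with the given provider, repeating the
/// request `repetition_count` times and recording each run's prompt, raw
/// model output, and resulting patch. Examples that already have predictions
/// are skipped.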
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) {
    if !example.predictions.is_empty() {
        return;
    }

    let provider = provider.expect("a prediction provider must be specified");

    run_context_retrieval(example, app_state.clone(), cx.clone()).await;

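    // Teacher providers call Anthropic directly, so they skip project
    // loading and the edit-prediction store entirely.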
    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = Progress::global().start(Step::Predict, &example.name);

        if example.prompt.is_none() {
            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
        }

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), cx.clone()).await;

    let _step_progress = Progress::global().start(Step::Predict, &example.name);

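    // Zeta providers require a signed-in client; authenticate once and share
    // the task across all examples in this process.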
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    client
                        .sign_in_with_optional_connect(true, cx)
                        .await
                        .unwrap();
                })
                .shared()
            })
            .clone()
            .await;
    }

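    // Point the global edit-prediction store at the model matching the
    // requested provider.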
    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
        .unwrap();

    ep_store
        .update(&mut cx, |store, _cx| {
            let model = match provider {
                PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
                PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
                PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
                PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
                PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                    unreachable!()
                }
            };
            store.set_edit_prediction_model(model);
        })
        .unwrap();
    let state = example.state.as_ref().unwrap();
    let run_dir = RUN_DIR.join(&example.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

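    // Drain debug events in the background, writing each run's prompt and
    // raw model response into its run directory as they arrive.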
    let mut debug_rx = ep_store
        .update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
        .unwrap();
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
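                        // Stop once the final repetition has finished.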
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

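    // Request one prediction per repetition; with multiple repetitions, each
    // run gets its own numbered subdirectory.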
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

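        // Repoint the "latest run" symlink at this run's directory.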
        fs::create_dir_all(&run_dir).unwrap();
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();

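        // Push a placeholder prediction; the debug task and the code below
        // fill in its output and patch.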
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .unwrap()
            .await
            .unwrap();

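        // Render the predicted edits as a unified diff; an empty patch means
        // no prediction was produced.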
        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction.edit_preview.as_unified_diff(&prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            _step_progress.set_info(info, style);
        }
    }

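    // Release the project from the store, wait for the debug task to drain,
    // then move the accumulated results back into `example`.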
    ep_store
        .update(&mut cx, |store, _| {
            store.remove_project(&state.project);
        })
        .unwrap();
    debug_task.await.unwrap();

    *example = Arc::into_inner(updated_example)
        .unwrap()
        .into_inner()
        .unwrap();
}

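/// Requests a prediction from Claude using the example's teacher prompt and
/// records the parsed patch. With `batched`, the client may stash the request
/// for batch processing and return nothing; see `sync_batches`.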
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = if batched {
        AnthropicClient::batch(crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.expect("Failed to create LLM client");

    let prompt = example
        .prompt
        .as_ref()
        .unwrap_or_else(|| panic!("Prompt is required for example {}", &example.name));

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await
        .unwrap()
    else {
        // Request stashed for batched processing
        return;
    };

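    // Concatenate the text blocks of the response; other content kinds are
    // ignored.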
    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output);

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
}

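/// Syncs pending Anthropic batch requests. Only the batching `Teacher`
/// provider uses batches, so all other providers are a no-op.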
pub async fn sync_batches(provider: &PredictionProvider) {
    if matches!(provider, PredictionProvider::Teacher) {
        let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
        let llm_client =
            AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
        llm_client
            .sync_batches()
            .await
            .expect("Failed to sync batches");
    }
}