predict.rs

use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

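/// Runs edit prediction for `example` with the given provider and records the
/// resulting predictions. Examples that already have predictions are skipped.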
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    progress: Arc<Progress>,
    mut cx: AsyncApp,
) {
    if !example.predictions.is_empty() {
        return;
    }

    let provider = provider.unwrap();

    run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;

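    // The teacher providers prompt Anthropic directly from the formatted teacher
    // prompt, so they skip loading the project and the local edit-prediction store.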
    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = progress.start(Step::Predict, &example.name);

        if example.prompt.is_none() {
            run_format_prompt(
                example,
                PromptFormat::Teacher,
                app_state.clone(),
                progress,
                cx,
            )
            .await;
        }

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await;

    let _step_progress = progress.start(Step::Predict, &example.name);

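    // The Zeta providers go through the cloud client, so sign in once per process
    // and share that task across all examples.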
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    client
                        .sign_in_with_optional_connect(true, cx)
                        .await
                        .unwrap();
                })
                .shared()
            })
            .clone()
            .await;
    }

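    // Configure the global EditPredictionStore with the model matching the provider.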
    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
        .unwrap();

    ep_store
        .update(&mut cx, |store, _cx| {
            let model = match provider {
                PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
                PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
                PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
                PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
                PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                    unreachable!()
                }
            };
            store.set_edit_prediction_model(model);
        })
        .unwrap();
    let state = example.state.as_ref().unwrap();
    let run_dir = RUN_DIR.join(&example.name);

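    // Shared between the prediction loop below and the debug-event task, which both
    // update the prediction for the run currently in flight.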
    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

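    // Listen for the store's debug events so each run's prompt and raw model response
    // can be written into its run directory as they arrive.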
    let mut debug_rx = ep_store
        .update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
        .unwrap();
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // Stop after the final repetition's finished event (run_ix is zero-based).
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

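    // Request one prediction per repetition and record the resulting patch.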
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        // Point LATEST_EXAMPLE_RUN_DIR at this run's directory.
        fs::create_dir_all(&run_dir).unwrap();
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();

        // Push a placeholder prediction; the debug task and the code below fill it in.
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .unwrap()
            .await
            .unwrap();

        // Render the predicted edits as a unified diff; empty means no prediction.
        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction.edit_preview.as_unified_diff(&prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            _step_progress.set_info(info, style);
        }
    }

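    // Tear down the store's project state, let the debug task finish, and write the
    // accumulated predictions back into `example`.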
    ep_store
        .update(&mut cx, |store, _| {
            store.remove_project(&state.project);
        })
        .unwrap();
    debug_task.await.unwrap();

    *example = Arc::into_inner(updated_example)
        .unwrap()
        .into_inner()
        .unwrap();
}

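/// Sends the example's teacher prompt to Anthropic and records the prediction.
/// In batched mode the request may only be stashed for later batch processing,
/// in which case no prediction is recorded until the batch is synced.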
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = if batched {
        AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.expect("Failed to create LLM client");

    let prompt = example
        .prompt
        .as_ref()
        .unwrap_or_else(|| panic!("Prompt is required for example {}", &example.name));

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await
        .unwrap()
    else {
        // Request stashed for batched processing
        return;
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output);

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
}

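/// Synchronizes pending Anthropic batch requests for providers that use the
/// batching client (currently only `Teacher`).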
pub async fn sync_batches(provider: &PredictionProvider) {
    match provider {
        PredictionProvider::Teacher => {
            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
            let llm_client =
                AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
            llm_client
                .sync_batches()
                .await
                .expect("Failed to sync batches");
        }
        _ => (),
    }
}