predict.rs

use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    retrieve_context::run_context_retrieval,
};
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

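/// Runs edit prediction for `example` with the requested provider. Teacher providers are sent
/// to Anthropic directly; other providers go through the `EditPredictionStore`, with each
/// repetition's prompt and response written into the example's run directory.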
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) {
    // Skip examples that already have predictions recorded.
    if !example.predictions.is_empty() {
        return;
    }

    run_context_retrieval(example, app_state.clone(), cx.clone()).await;

    let provider = provider.expect("a prediction provider is required");

    // Teacher providers are served by the Anthropic API rather than the edit prediction store.
    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        if example.prompt.is_none() {
            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
        }

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), cx.clone()).await;

    // Zeta providers need an authenticated client; sign in once and share the task across calls.
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    client
                        .sign_in_with_optional_connect(true, cx)
                        .await
                        .unwrap();
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
        .unwrap();

    // Select the edit prediction model matching the requested provider.
    ep_store
        .update(&mut cx, |store, _cx| {
            let model = match provider {
                PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
                PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
                PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
                PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
                PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                    unreachable!()
                }
            };
            store.set_edit_prediction_model(model);
        })
        .unwrap();
    let state = example.state.as_ref().unwrap();
    let run_dir = RUN_DIR.join(&example.name);

    // Shared between the prediction loop below and the debug-event task.
    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    let mut debug_rx = ep_store
        .update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
        .unwrap();
    // Background task that mirrors debug events to disk: it saves each run's prompt and raw
    // model response alongside the example's run directory.
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // Stop listening once the final repetition's response has been recorded.
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        // Each repetition gets its own numbered subdirectory when running more than once.
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        // Point the "latest example run" symlink at this run's directory.
        fs::create_dir_all(&run_dir).unwrap();
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();

        // Reserve a prediction slot; the debug task and the code below fill it in.
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .unwrap()
            .await
            .unwrap();

        // Record the predicted edits as a unified diff on the prediction slot reserved above.
        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction.edit_preview.as_unified_diff(&prediction.edits)
            })
            .unwrap_or_default();
    }

    // Clean up the project from the store and wait for the debug writer task to finish.
    ep_store
        .update(&mut cx, |store, _| {
            store.remove_project(&state.project);
        })
        .unwrap();
    debug_task.await.unwrap();

    // Move the accumulated predictions back into the caller's example.
    *example = Arc::into_inner(updated_example)
        .unwrap()
        .into_inner()
        .unwrap();
}

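/// Sends the example's teacher prompt to Anthropic and records the resulting prediction.
/// When `batched` is set, the request may be stashed for later batch processing instead of
/// returning a response immediately.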
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    // Batched requests go through the local LLM cache database.
    let llm_client = if batched {
        AnthropicClient::batch(crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.expect("Failed to create LLM client");

    let prompt = example
        .prompt
        .as_ref()
        .unwrap_or_else(|| panic!("Prompt is required for example {}", &example.name));

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await
        .unwrap()
    else {
        // Request stashed for batched processing; the response will arrive via a later batch sync.
        return;
    };

    // Concatenate the text blocks of the response into the raw model output.
    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    // Extract the patch from the teacher-formatted output.
    let actual_patch = TeacherPrompt::parse(example, &actual_output);

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
}

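/// Flushes pending Anthropic batch requests. Only the batching teacher provider has batches
/// to sync; other providers are a no-op.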
pub async fn sync_batches(provider: &PredictionProvider) {
    if let PredictionProvider::Teacher = provider {
        let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
        let llm_client =
            AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
        llm_client
            .sync_batches()
            .await
            .expect("Failed to sync batches");
    }
}