use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{PromptParser, TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    retrieve_context::run_context_retrieval,
};
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

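/// Runs edit prediction for `example`: loads the project, retrieves context,
/// then either queries Anthropic directly (Teacher) or goes through the
/// `EditPredictionStore`, recording each run's prompt, response, and patch.
///
/// Illustrative call (hypothetical setup; assumes `example`, `app_state`, and
/// `cx` are already constructed):
///
/// ```ignore
/// run_prediction(&mut example, Some(PredictionProvider::Zeta2), 1, app_state, cx).await;
/// ```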
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) {
    if !example.predictions.is_empty() {
        return;
    }

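    // Make sure the example's project is loaded and its context retrieved
    // before any prediction is requested.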
    run_load_project(example, app_state.clone(), cx.clone()).await;
    run_context_retrieval(example, app_state.clone(), cx.clone()).await;

    let provider = provider.unwrap();

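    // The Teacher provider bypasses the edit prediction store and queries
    // Anthropic directly, formatting a teacher prompt first if needed.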
    if matches!(provider, PredictionProvider::Teacher) {
        if example.prompt.is_none() {
            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
        }

        let batched = true;
        return predict_anthropic(example, repetition_count, batched).await;
    }

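    // Zeta providers require a signed-in client; sign in once and share the
    // task so repeated calls don't authenticate again.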
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    client
                        .sign_in_with_optional_connect(true, cx)
                        .await
                        .unwrap();
                })
                .shared()
            })
            .clone()
            .await;
    }

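    // Select the edit prediction model corresponding to the requested provider.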
    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
        .unwrap();

    ep_store
        .update(&mut cx, |store, _cx| {
            let model = match provider {
                PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
                PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
                PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
                PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
                PredictionProvider::Teacher => unreachable!(),
            };
            store.set_edit_prediction_model(model);
        })
        .unwrap();
    let state = example.state.as_ref().unwrap();
    let run_dir = RUN_DIR.join(&example.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

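    // Listen to debug events from the store so each run's prompt and model
    // output can be written to its run directory as they arrive.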
    let mut debug_rx = ep_store
        .update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
        .unwrap();
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // Stop listening once the final repetition has finished
                        // (`run_ix` is zero-based, so the last run is `repetition_count - 1`).
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

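    // Run the requested number of repetitions, each writing into its own
    // zero-padded run directory when more than one repetition is requested.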
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir).unwrap();
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .unwrap()
            .await
            .unwrap();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction.edit_preview.as_unified_diff(&prediction.edits)
            })
            .unwrap_or_default();
    }

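    // Detach from the store, wait for the debug listener to finish, and write
    // the accumulated predictions back into the caller's example.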
    ep_store
        .update(&mut cx, |store, _| {
            store.remove_project(&state.project);
        })
        .unwrap();
    debug_task.await.unwrap();

    *example = Arc::into_inner(updated_example)
        .unwrap()
        .into_inner()
        .unwrap();
}

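/// Requests a single prediction from Anthropic for the Teacher provider.
///
/// In batched mode the request may be stashed for later processing, in which
/// case `generate` returns `None` and no prediction is recorded
/// (see `sync_batches`).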
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = if batched {
        AnthropicClient::batch(crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.expect("Failed to create LLM client");

    let prompt = example
        .prompt
        .as_ref()
        .unwrap_or_else(|| panic!("Prompt is required for example {}", &example.name));

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await
        .unwrap()
    else {
        // Request stashed for batched processing
        return;
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output);

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
}

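/// Syncs any pending Anthropic batch requests. Only the Teacher provider uses
/// batching; for every other provider this is a no-op.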
pub async fn sync_batches(provider: &PredictionProvider) {
    match provider {
        PredictionProvider::Teacher => {
            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
            let llm_client =
                AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
            llm_client
                .sync_batches()
                .await
                .expect("Failed to sync batches");
        }
        _ => (),
    }
}