predict.rs

use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

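/// Runs edit prediction for `example` with the given provider, repeating the
/// request `repetition_count` times. Teacher providers are routed straight to
/// the Anthropic API; all other providers go through the `EditPredictionStore`
/// against a loaded project.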
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let provider = provider.context("provider is required")?;

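    // Reuse cached predictions from the same provider; discard predictions
    // that were produced by a different one.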
    if let Some(existing_prediction) = example.predictions.first() {
        if existing_prediction.provider == provider {
            return Ok(());
        } else {
            example.predictions.clear();
        }
    }

    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;

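    // Teacher providers call the Anthropic API directly and don't need the
    // project loaded into the store.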
    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);

        if example.prompt.is_none() {
            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
        }

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), cx.clone()).await?;

    let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);

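    // Zeta providers require a signed-in client. Authenticate at most once per
    // process; concurrent callers all await the same shared task.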
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx.update(|cx| {
        EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
    })??;

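    // Point the store at the model for the requested provider. The teacher
    // variants returned earlier, so they are unreachable here.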
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    })?;
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

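    // Share the example with the debug listener below, which records the
    // prompt and model output for whichever run is currently active.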
    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    let mut debug_rx =
        ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // Stop listening once the final repetition finishes
                        // (`run_ix` is zero-based, so the last run is
                        // `repetition_count - 1`).
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

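    // Request one prediction per repetition. When there are multiple runs,
    // each one gets its own zero-padded subdirectory.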
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
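        // Repoint the "latest run" symlink at this run's directory.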
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

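        // Push a placeholder prediction for this run; the debug listener and
        // the code below fill in the output and patch.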
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })?
            .await?;

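        // Render the predicted edits as a unified diff. An empty patch means
        // the model produced no prediction.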
        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction
                    .edit_preview
                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            _step_progress.set_info(info, style);
        }
    }

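    // Detach the project so the debug stream closes, then wait for the
    // listener to flush its final events.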
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    })?;
    debug_task.await?;

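    // The debug task has completed, so these should be the only remaining
    // handles to the example.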
    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}
228
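/// Sends the teacher prompt to Anthropic and records the response as a
/// prediction. In batched mode the request may only be enqueued; in that case
/// nothing is recorded until `sync_batches` retrieves the results.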
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = if batched {
        AnthropicClient::batch(crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.context("Failed to create LLM client")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request was stashed for batched processing; the result is picked up
        // by a later `sync_batches` call.
        return Ok(());
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        // Record the provider that actually produced this prediction so the
        // cache check in `run_prediction` matches on re-runs.
        provider: if batched {
            PredictionProvider::Teacher
        } else {
            PredictionProvider::TeacherNonBatching
        },
    };

    example.predictions.push(prediction);
    Ok(())
}
282
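/// Synchronizes pending Anthropic batch requests, retrieving any completed
/// results. Only the batching `Teacher` provider has batches to sync.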
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher => {
            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
            let llm_client =
                AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    }
    Ok(())
}