predict.rs

use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

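/// Runs edit prediction for `example` with the given provider, repeating it
/// `repetition_count` times. Examples that already carry predictions are
/// skipped. Teacher providers query Anthropic directly; every other provider
/// goes through the shared `EditPredictionStore`.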
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    if !example.predictions.is_empty() {
        return Ok(());
    }

    let provider = provider.context("provider is required")?;

    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;

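    // Teacher providers don't need the project loaded: they format a teacher
    // prompt and query Anthropic directly, optionally via the batching client.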
    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);

        if example.prompt.is_none() {
            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
        }

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), cx.clone()).await?;

    let step_progress = Progress::global().start(Step::Predict, &example.spec.name);

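    // Zeta providers talk to the cloud service, so sign in once and share the
    // authentication task across every example in the run.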
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

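    // Resolve the global prediction store and point it at the model backing
    // the selected provider.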
    let ep_store = cx.update(|cx| {
        EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
    })??;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    })?;
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

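    // Tail the store's debug events on a background task, mirroring each run's
    // prompt and raw model output into its run directory as they arrive.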
    let mut debug_rx =
        ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // `run_ix` is zero-based, so the final repetition is
                        // `repetition_count - 1`; stop once it has finished.
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

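    // Run the prediction `repetition_count` times, giving each run its own
    // numbered subdirectory when more than one repetition was requested.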
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
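        // Repoint the "latest run" symlink at this run's directory.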
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

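        // Request a prediction at the example's cursor position. Debug events
        // emitted during this call are picked up by the background task above
        // and written into `run_dir`.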
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })?
            .await?;

        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction
                    .edit_preview
                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

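    // Detach the project so the debug stream closes, then wait for the
    // background task to drain any remaining events.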
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    })?;
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

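/// Predicts edits for a teacher example by prompting Claude directly. When
/// `batched` is true, the request goes through the Anthropic batching client
/// and may be stashed for later processing instead of returning a response.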
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = if batched {
        AnthropicClient::batch(crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.context("Failed to create LLM client")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
    Ok(())
}

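/// Flushes any pending Anthropic batch requests. Only the batching `Teacher`
/// provider stashes work, so this is a no-op for every other provider.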
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher => {
            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
            let llm_client =
                AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}