predict.rs

use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

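/// Lazily initialized Anthropic client shared across examples; whichever
/// caller initializes it first decides whether it runs in batched mode.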
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();

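/// Runs edit prediction for a single example: skips examples that already
/// have a prediction from the requested provider, routes teacher providers
/// through the Anthropic prompt path, and otherwise drives the local
/// `EditPredictionStore` for the configured number of repetitions.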
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let provider = args.provider;
    let repetition_count = args.repetitions;
    let zeta_version = args.version;

    if let Some(existing_prediction) = example.predictions.first() {
        if existing_prediction.provider == provider {
            return Ok(());
        } else {
            example.predictions.clear();
        }
    }

    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;

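    // Teacher providers bypass the local store: format the prompt and send it
    // directly to Anthropic (batched for `Teacher`, immediate otherwise).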
    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);

        run_format_prompt(
            example,
            &FormatPromptArgs {
                provider,
                version: args.version,
            },
            app_state.clone(),
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), cx.clone()).await?;

    let step_progress = Progress::global().start(Step::Predict, &example.spec.name);

    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
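        // Sign in at most once per process; subsequent examples await the
        // same shared task.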
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2 {
                version: zeta_version,
            },
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

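    // Stream debug events in the background so each run's prompt and raw model
    // output can be written to disk and recorded on the example.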
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if provider == PredictionProvider::Zeta2 {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // Stop once the final repetition's response has been recorded.
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

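    // Request one prediction per repetition, writing artifacts into a per-run
    // directory when multiple repetitions are requested.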
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
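        // Repoint the "latest run" symlink at this run's directory.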
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

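        // Render the predicted edits as a unified diff; an empty patch means
        // no prediction was produced for this run.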
        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction
                    .edit_preview
                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

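/// Sends the example's formatted prompt to Anthropic and records the parsed
/// patch. In batched mode a request may be stashed for later processing, in
/// which case no prediction is recorded yet.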
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

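    // Concatenate the text blocks of the response into a single output string.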
    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
    Ok(())
}

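/// Syncs any pending Anthropic batch requests for the `Teacher` provider;
/// other providers are a no-op.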
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher => {
            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client")
            });
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}