predict.rs

use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaVersion;

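/// Process-wide Anthropic client, created lazily the first time a teacher
/// prediction or a batch sync needs it.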
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();

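/// Produces edit predictions for `example` using the provider selected in `args`.
/// Teacher providers are routed straight to Anthropic; all other providers go
/// through the `EditPredictionStore`, repeating the request `args.repetitions`
/// times and recording every prediction (plus prompt/response artifacts) on the
/// example.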
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let provider = args.provider;
    let repetition_count = args.repetitions;

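    // Reuse an existing prediction when it came from the same provider; otherwise
    // drop stale predictions so they are regenerated below.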
    if let Some(existing_prediction) = example.predictions.first() {
        if existing_prediction.provider == provider {
            return Ok(());
        } else {
            example.predictions.clear();
        }
    }

    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;

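    // Teacher providers bypass the edit prediction store: format the teacher
    // prompt, then query Anthropic directly.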
    if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
        args.provider
    {
        let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_anthropic(example, repetition_count, version, batched).await;
    }

    run_load_project(example, app_state.clone(), cx.clone()).await?;

    let step_progress = Progress::global().start(Step::Predict, &example.spec.name);

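    // Zeta providers need an authenticated client; sign in once per process and
    // share the resulting task across examples.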
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

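    // Configure the store with the model that corresponds to the requested
    // provider; teacher providers returned above and cannot reach this point.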
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(version) => {
                edit_prediction::EditPredictionModel::Zeta2 { version }
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..) | PredictionProvider::TeacherNonBatching(..) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

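    // Listen for the store's debug events in the background so each run's prompt
    // and raw model output can be written to disk and attached to the example.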
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // Stop listening once the final repetition has finished.
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

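    // Request one prediction per repetition; with multiple repetitions each run
    // gets its own numbered subdirectory under the example's run directory.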
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
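        // Repoint the "latest example run" symlink at this run's directory.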
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

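        // Render the predicted edits as a unified diff; an empty patch is treated
        // as "no prediction".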
        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction
                    .edit_preview
                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

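    // Move the accumulated results back into the caller's example now that the
    // debug task has dropped its clones of the Arc and Mutex.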
    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

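/// Requests a single completion from Anthropic for a teacher provider and records
/// it on the example. In batched mode the request may only be enqueued (the client
/// returns no response), in which case nothing is recorded until batches are synced.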
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    version: ZetaVersion,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(&example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: if batched {
            PredictionProvider::Teacher(version)
        } else {
            PredictionProvider::TeacherNonBatching(version)
        },
    };

    example.predictions.push(prediction);
    Ok(())
}

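/// Synchronizes Anthropic batch requests when the batching `Teacher` provider is
/// in use; a no-op for every other provider.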
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher(..) => {
            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client")
            });
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}