// predict.rs

  1use crate::{
  2    FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
  3    anthropic_client::AnthropicClient,
  4    example::{Example, ExamplePrediction, ExamplePrompt},
  5    format_prompt::{TeacherPrompt, run_format_prompt},
  6    headless::EpAppState,
  7    load_project::run_load_project,
  8    openai_client::OpenAiClient,
  9    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
 10    progress::{ExampleProgress, InfoStyle, Step},
 11    retrieve_context::run_context_retrieval,
 12};
 13use anyhow::Context as _;
 14use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
 15use futures::{FutureExt as _, StreamExt as _, future::Shared};
 16use gpui::{AppContext as _, AsyncApp, Task};
 17use std::{
 18    fs,
 19    sync::{
 20        Arc, Mutex, OnceLock,
 21        atomic::{AtomicUsize, Ordering::SeqCst},
 22    },
 23};
 24use zeta_prompt::ZetaFormat;
 25
// Process-wide LLM clients, created lazily on first use.
// NOTE: whichever call initializes one of these first also fixes its mode
// (batch-cached vs. plain) for the rest of the process — later callers get
// the already-constructed client regardless of their own `batched` flag.
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
 28
/// Runs edit prediction for one example, filling `example.predictions` and
/// writing per-run artifacts (prompt/response files) under the run directory.
///
/// Flow:
/// * Skip entirely if a usable prediction from the requested (or an
///   unspecified) provider already exists; predictions from a *different*
///   provider are discarded and regenerated.
/// * Teacher providers go through prompt formatting + an external LLM client
///   (`predict_teacher`) and return early.
/// * All other providers load the project and drive the in-process
///   `EditPredictionStore`, once per repetition.
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    // Reuse existing results when possible; only re-predict when the caller
    // explicitly asked for a different provider.
    if let Some(existing_prediction) = example.predictions.first() {
        // "Has a prediction" means either a parsed patch or raw model output.
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                // No provider requested: keep whatever is already there.
                None => return Ok(()),
                // Same provider: nothing to do.
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                // Different provider: drop stale predictions and rerun below.
                Some(_) => example.predictions.clear(),
            }
        }
    }

    // Past this point we must actually predict, so a provider is mandatory.
    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    // Teacher providers don't use the in-process store: format the prompt and
    // call the external LLM client, then return early.
    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        // `Teacher` uses the batch API; `TeacherNonBatching` issues requests directly.
        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(example, backend, batched, repetition_count, args.cache_only).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    // Zeta providers need a signed-in client. Authenticate at most once per
    // process; the Shared task lets concurrent examples await the same sign-in.
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    // Failure is logged, not fatal: the prediction request
                    // below will surface any real authorization error.
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    // Select the model implementation matching the requested provider.
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta2,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::Repair => {
                // Teacher variants returned early above; Repair never reaches
                // this code path.
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If user specified a non-default Zeta2 version, configure raw endpoint.
        // ZED_ZETA_MODEL env var is optional.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    // The debug task below and this function both mutate the example, so the
    // working copy lives behind Arc<Mutex<..>> and is unwrapped at the end.
    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    // Background task: consume debug events from the store, persist the raw
    // prompt/response per run, and record the model output on the prediction
    // that the request loop pushed for the current repetition.
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                // With multiple repetitions, each run gets its own numbered
                // subdirectory (mirrors the layout in the request loop below).
                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        // The request loop pushes a placeholder prediction
                        // before each request, so len() == run_ix + 1 here.
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            // Only Zeta2 exposes a prompt we keep on the
                            // example itself (first run wins).
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // NOTE(review): `current_run_ix` is stored as at most
                        // `repetition_count - 1` by the loop below, so this
                        // condition never fires; the task actually ends when
                        // `debug_rx` closes. Confirm whether
                        // `run_ix + 1 >= repetition_count` was intended.
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

    for ix in 0..repetition_count {
        // Publish the current repetition index to the debug task.
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        // Point the "latest run" symlink at this run's directory.
        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        // Placeholder prediction that the debug task and the code below fill
        // in; its existence is what the asserts in the debug task check.
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        // Convert the predicted edits (if any) into a unified diff.
        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        // Report only for the final repetition.
        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    // Detach the project so the debug stream ends, then wait for the debug
    // task to flush before unwrapping the shared example back into `example`.
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}
268
269async fn predict_teacher(
270    example: &mut Example,
271    backend: TeacherBackend,
272    batched: bool,
273    repetition_count: usize,
274    cache_only: bool,
275) -> anyhow::Result<()> {
276    match backend {
277        TeacherBackend::Sonnet45 => {
278            predict_anthropic(example, backend, batched, repetition_count, cache_only).await
279        }
280        TeacherBackend::Gpt52 => {
281            predict_openai(example, backend, batched, repetition_count, cache_only).await
282        }
283    }
284}
285
286async fn predict_anthropic(
287    example: &mut Example,
288    backend: TeacherBackend,
289    batched: bool,
290    repetition_count: usize,
291    cache_only: bool,
292) -> anyhow::Result<()> {
293    let llm_model_name = backend.model_name();
294    let max_tokens = 16384;
295    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
296        let client = if batched {
297            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
298        } else {
299            AnthropicClient::plain()
300        };
301        client.expect("Failed to create Anthropic client")
302    });
303
304    let prompt = example.prompt.as_ref().context("Prompt is required")?;
305
306    for ix in 0..repetition_count {
307        let messages = vec![anthropic::Message {
308            role: anthropic::Role::User,
309            content: vec![anthropic::RequestContent::Text {
310                text: prompt.input.clone(),
311                cache_control: None,
312            }],
313        }];
314
315        let seed = if repetition_count > 1 { Some(ix) } else { None };
316        let Some(response) = llm_client
317            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
318            .await?
319        else {
320            // Request stashed for batched processing
321            return Ok(());
322        };
323
324        let actual_output = response
325            .content
326            .into_iter()
327            .filter_map(|content| match content {
328                anthropic::ResponseContent::Text { text } => Some(text),
329                _ => None,
330            })
331            .collect::<Vec<String>>()
332            .join("\n");
333
334        let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
335
336        let prediction = ExamplePrediction {
337            actual_patch: Some(actual_patch),
338            actual_output,
339            actual_cursor,
340            error: None,
341            provider: if batched {
342                PredictionProvider::Teacher(backend)
343            } else {
344                PredictionProvider::TeacherNonBatching(backend)
345            },
346        };
347
348        example.predictions.push(prediction);
349    }
350    Ok(())
351}
352
353async fn predict_openai(
354    example: &mut Example,
355    backend: TeacherBackend,
356    batched: bool,
357    repetition_count: usize,
358    cache_only: bool,
359) -> anyhow::Result<()> {
360    let llm_model_name = backend.model_name();
361    let max_tokens = 16384;
362    let llm_client = OPENAI_CLIENT.get_or_init(|| {
363        let client = if batched {
364            OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
365        } else {
366            OpenAiClient::plain()
367        };
368        client.expect("Failed to create OpenAI client")
369    });
370
371    let prompt = example.prompt.as_ref().context("Prompt is required")?;
372
373    for ix in 0..repetition_count {
374        let messages = vec![open_ai::RequestMessage::User {
375            content: open_ai::MessageContent::Plain(prompt.input.clone()),
376        }];
377
378        let seed = if repetition_count > 1 { Some(ix) } else { None };
379        let Some(response) = llm_client
380            .generate(llm_model_name, max_tokens, messages, seed, cache_only)
381            .await?
382        else {
383            // Request stashed for batched processing
384            return Ok(());
385        };
386
387        let actual_output = response
388            .choices
389            .into_iter()
390            .filter_map(|choice| match choice.message {
391                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
392                    open_ai::MessageContent::Plain(text) => text,
393                    open_ai::MessageContent::Multipart(parts) => parts
394                        .into_iter()
395                        .filter_map(|p| match p {
396                            open_ai::MessagePart::Text { text } => Some(text),
397                            _ => None,
398                        })
399                        .collect::<Vec<_>>()
400                        .join(""),
401                }),
402                _ => None,
403            })
404            .collect::<Vec<String>>()
405            .join("\n");
406
407        let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
408
409        let prediction = ExamplePrediction {
410            actual_patch: Some(actual_patch),
411            actual_output,
412            actual_cursor,
413            error: None,
414            provider: if batched {
415                PredictionProvider::Teacher(backend)
416            } else {
417                PredictionProvider::TeacherNonBatching(backend)
418            },
419        };
420
421        example.predictions.push(prediction);
422    }
423    Ok(())
424}
425
426pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
427    match provider {
428        Some(PredictionProvider::Teacher(backend)) => match backend {
429            TeacherBackend::Sonnet45 => {
430                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
431                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
432                        .expect("Failed to create Anthropic client")
433                });
434                llm_client
435                    .sync_batches()
436                    .await
437                    .context("Failed to sync Anthropic batches")?;
438            }
439            TeacherBackend::Gpt52 => {
440                let llm_client = OPENAI_CLIENT.get_or_init(|| {
441                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
442                        .expect("Failed to create OpenAI client")
443                });
444                llm_client
445                    .sync_batches()
446                    .await
447                    .context("Failed to sync OpenAI batches")?;
448            }
449        },
450        _ => (),
451    };
452    Ok(())
453}