From e1e356d81cf35f73ede2c84b721b72f9a51a652c Mon Sep 17 00:00:00 2001
From: Max Brunsfeld
Date: Tue, 27 Jan 2026 11:08:32 -0800
Subject: [PATCH] Make --repetitions flag work for teacher

---
 .../src/anthropic_client.rs                  |  31 ++-
 .../edit_prediction_cli/src/openai_client.rs |  26 ++-
 crates/edit_prediction_cli/src/predict.rs    | 179 ++++++++++--------
 crates/edit_prediction_cli/src/qa.rs         |   4 +-
 crates/edit_prediction_cli/src/repair.rs     |   4 +-
 5 files changed, 143 insertions(+), 101 deletions(-)

diff --git a/crates/edit_prediction_cli/src/anthropic_client.rs b/crates/edit_prediction_cli/src/anthropic_client.rs
index c37e141f1352a2b8a25ba089aa84ac1769fb476d..242a38dd6104c1c173d5bd978fc8b41d69f1edff 100644
--- a/crates/edit_prediction_cli/src/anthropic_client.rs
+++ b/crates/edit_prediction_cli/src/anthropic_client.rs
@@ -208,8 +208,9 @@ impl BatchingLlmClient {
         model: &str,
         max_tokens: u64,
         messages: &[Message],
+        seed: Option<usize>,
     ) -> Result<Option<Response>> {
-        let request_hash_str = Self::request_hash(model, max_tokens, messages);
+        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
         let connection = self.connection.lock().unwrap();
         let response: Vec<String> = connection.select_bound(
             &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
@@ -220,8 +221,14 @@ impl BatchingLlmClient {
             .and_then(|text| serde_json::from_str(&text).ok()))
     }
 
-    pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
-        let request_hash = Self::request_hash(model, max_tokens, messages);
+    pub fn mark_for_batch(
+        &self,
+        model: &str,
+        max_tokens: u64,
+        messages: &[Message],
+        seed: Option<usize>,
+    ) -> Result<()> {
+        let request_hash = Self::request_hash(model, max_tokens, messages, seed);
 
         let serializable_messages: Vec = messages
             .iter()
@@ -259,13 +266,14 @@ impl BatchingLlmClient {
         model: &str,
         max_tokens: u64,
         messages: Vec<Message>,
+        seed: Option<usize>,
     ) -> Result<Option<Response>> {
-        let response = self.lookup(model, max_tokens, &messages)?;
+        let response = self.lookup(model, max_tokens, &messages, seed)?;
         if let Some(response) = response {
             return Ok(Some(response));
         }
 
-        self.mark_for_batch(model, max_tokens, &messages)?;
+        self.mark_for_batch(model, max_tokens, &messages, seed)?;
 
         Ok(None)
     }
@@ -606,13 +614,21 @@ impl BatchingLlmClient {
         Ok(all_batch_ids)
     }
 
-    fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
+    fn request_hash(
+        model: &str,
+        max_tokens: u64,
+        messages: &[Message],
+        seed: Option<usize>,
+    ) -> String {
         let mut hasher = std::hash::DefaultHasher::new();
         model.hash(&mut hasher);
         max_tokens.hash(&mut hasher);
         for msg in messages {
             message_content_to_string(&msg.content).hash(&mut hasher);
         }
+        if let Some(seed) = seed {
+            seed.hash(&mut hasher);
+        }
         let request_hash = hasher.finish();
         format!("{request_hash:016x}")
     }
@@ -655,6 +671,7 @@ impl AnthropicClient {
         model: &str,
         max_tokens: u64,
         messages: Vec<Message>,
+        seed: Option<usize>,
     ) -> Result<Option<Response>> {
         match self {
             AnthropicClient::Plain(plain_llm_client) => plain_llm_client
@@ -663,7 +680,7 @@ impl AnthropicClient {
                 .map(Some),
             AnthropicClient::Batch(batching_llm_client) => {
                 batching_llm_client
-                    .generate(model, max_tokens, messages)
+                    .generate(model, max_tokens, messages, seed)
                     .await
             }
             AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
diff --git a/crates/edit_prediction_cli/src/openai_client.rs b/crates/edit_prediction_cli/src/openai_client.rs
index ad402c3472b238d0d822c9564d449c63d581fa53..15d31c997183aa7a5dc12669a7016be8eb9a6ad4 100644
--- a/crates/edit_prediction_cli/src/openai_client.rs
+++ b/crates/edit_prediction_cli/src/openai_client.rs
@@ -138,8 +138,9 @@ impl BatchingOpenAiClient {
         model: &str,
         max_tokens: u64,
         messages: &[RequestMessage],
+        seed: Option<usize>,
     ) -> Result<Option<Response>> {
-        let request_hash_str = Self::request_hash(model, max_tokens, messages);
+        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
         let connection = self.connection.lock().unwrap();
         let response: Vec<String> = connection.select_bound(
             &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
@@ -155,8 +156,9 @@ impl BatchingOpenAiClient {
         model: &str,
         max_tokens: u64,
         messages: &[RequestMessage],
+        seed: Option<usize>,
     ) -> Result<()> {
-        let request_hash = Self::request_hash(model, max_tokens, messages);
+        let request_hash = Self::request_hash(model, max_tokens, messages, seed);
 
         let serializable_messages: Vec = messages
             .iter()
@@ -191,13 +193,14 @@ impl BatchingOpenAiClient {
         model: &str,
         max_tokens: u64,
         messages: Vec<RequestMessage>,
+        seed: Option<usize>,
     ) -> Result<Option<Response>> {
-        let response = self.lookup(model, max_tokens, &messages)?;
+        let response = self.lookup(model, max_tokens, &messages, seed)?;
         if let Some(response) = response {
             return Ok(Some(response));
         }
 
-        self.mark_for_batch(model, max_tokens, &messages)?;
+        self.mark_for_batch(model, max_tokens, &messages, seed)?;
 
         Ok(None)
     }
@@ -558,7 +561,12 @@ impl BatchingOpenAiClient {
         Ok(all_batch_ids)
     }
 
-    fn request_hash(model: &str, max_tokens: u64, messages: &[RequestMessage]) -> String {
+    fn request_hash(
+        model: &str,
+        max_tokens: u64,
+        messages: &[RequestMessage],
+        seed: Option<usize>,
+    ) -> String {
         let mut hasher = std::hash::DefaultHasher::new();
         "openai".hash(&mut hasher);
         model.hash(&mut hasher);
@@ -566,6 +574,9 @@ impl BatchingOpenAiClient {
         for msg in messages {
             message_content_to_string(msg).hash(&mut hasher);
         }
+        if let Some(seed) = seed {
+            seed.hash(&mut hasher);
+        }
         let request_hash = hasher.finish();
         format!("{request_hash:016x}")
     }
@@ -631,6 +642,7 @@ impl OpenAiClient {
         model: &str,
         max_tokens: u64,
         messages: Vec<RequestMessage>,
+        seed: Option<usize>,
     ) -> Result<Option<Response>> {
         match self {
             OpenAiClient::Plain(plain_client) => plain_client
@@ -638,7 +650,9 @@ impl OpenAiClient {
                 .await
                 .map(Some),
             OpenAiClient::Batch(batching_client) => {
-                batching_client.generate(model, max_tokens, messages).await
+                batching_client
+                    .generate(model, max_tokens, messages, seed)
+                    .await
             }
             OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
         }
diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs
index c0b6af3f71de6ee3134e449c7db50b129c3b221b..19c2591b4fe3a1fdede82269da37af170ea4d2d7 100644
--- a/crates/edit_prediction_cli/src/predict.rs
+++ b/crates/edit_prediction_cli/src/predict.rs
@@ -69,7 +69,7 @@ pub async fn run_prediction(
             .await?;
 
         let batched = matches!(provider, PredictionProvider::Teacher(..));
-        return predict_teacher(example, backend, batched).await;
+        return predict_teacher(example, backend, batched, repetition_count).await;
     }
 
     run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
@@ -261,10 +261,13 @@ async fn predict_teacher(
     example: &mut Example,
     backend: TeacherBackend,
     batched: bool,
+    repetition_count: usize,
 ) -> anyhow::Result<()> {
     match backend {
-        TeacherBackend::Sonnet45 => predict_anthropic(example, backend, batched).await,
-        TeacherBackend::Gpt52 => predict_openai(example, backend, batched).await,
+        TeacherBackend::Sonnet45 => {
+            predict_anthropic(example, backend, batched, repetition_count).await
+        }
+        TeacherBackend::Gpt52 => predict_openai(example, backend, batched, repetition_count).await,
     }
 }
 
@@ -272,6 +275,7 @@ async fn predict_anthropic(
     example: &mut Example,
     backend: TeacherBackend,
     batched: bool,
+    repetition_count: usize,
 ) -> anyhow::Result<()> {
     let llm_model_name = backend.model_name();
     let max_tokens = 16384;
@@ -286,46 +290,49 @@ async fn predict_anthropic(
 
     let prompt = example.prompt.as_ref().context("Prompt is required")?;
 
-    let messages = vec![anthropic::Message {
-        role: anthropic::Role::User,
-        content: vec![anthropic::RequestContent::Text {
-            text: prompt.input.clone(),
-            cache_control: None,
-        }],
-    }];
-
-    let Some(response) = llm_client
-        .generate(llm_model_name, max_tokens, messages)
-        .await?
-    else {
-        // Request stashed for batched processing
-        return Ok(());
-    };
+    for ix in 0..repetition_count {
+        let messages = vec![anthropic::Message {
+            role: anthropic::Role::User,
+            content: vec![anthropic::RequestContent::Text {
+                text: prompt.input.clone(),
+                cache_control: None,
+            }],
+        }];
+
+        let seed = if repetition_count > 1 { Some(ix) } else { None };
+        let Some(response) = llm_client
+            .generate(llm_model_name, max_tokens, messages, seed)
+            .await?
+        else {
+            // Request stashed for batched processing
+            return Ok(());
+        };
 
-    let actual_output = response
-        .content
-        .into_iter()
-        .filter_map(|content| match content {
-            anthropic::ResponseContent::Text { text } => Some(text),
-            _ => None,
-        })
-        .collect::<Vec<_>>()
-        .join("\n");
-
-    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
-
-    let prediction = ExamplePrediction {
-        actual_patch: Some(actual_patch),
-        actual_output,
-        error: None,
-        provider: if batched {
-            PredictionProvider::Teacher(backend)
-        } else {
-            PredictionProvider::TeacherNonBatching(backend)
-        },
-    };
+        let actual_output = response
+            .content
+            .into_iter()
+            .filter_map(|content| match content {
+                anthropic::ResponseContent::Text { text } => Some(text),
+                _ => None,
+            })
+            .collect::<Vec<_>>()
+            .join("\n");
+
+        let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+
+        let prediction = ExamplePrediction {
+            actual_patch: Some(actual_patch),
+            actual_output,
+            error: None,
+            provider: if batched {
+                PredictionProvider::Teacher(backend)
+            } else {
+                PredictionProvider::TeacherNonBatching(backend)
+            },
+        };
 
-    example.predictions.push(prediction);
+        example.predictions.push(prediction);
+    }
 
     Ok(())
 }
@@ -333,6 +340,7 @@ async fn predict_openai(
     example: &mut Example,
     backend: TeacherBackend,
     batched: bool,
+    repetition_count: usize,
 ) -> anyhow::Result<()> {
     let llm_model_name = backend.model_name();
     let max_tokens = 16384;
@@ -347,52 +355,55 @@ async fn predict_openai(
 
     let prompt = example.prompt.as_ref().context("Prompt is required")?;
 
-    let messages = vec![open_ai::RequestMessage::User {
-        content: open_ai::MessageContent::Plain(prompt.input.clone()),
-    }];
+    for ix in 0..repetition_count {
+        let messages = vec![open_ai::RequestMessage::User {
+            content: open_ai::MessageContent::Plain(prompt.input.clone()),
+        }];
+
+        let seed = if repetition_count > 1 { Some(ix) } else { None };
+        let Some(response) = llm_client
+            .generate(llm_model_name, max_tokens, messages, seed)
+            .await?
+        else {
+            // Request stashed for batched processing
+            return Ok(());
+        };
 
-    let Some(response) = llm_client
-        .generate(llm_model_name, max_tokens, messages)
-        .await?
-    else {
-        // Request stashed for batched processing
-        return Ok(());
-    };
+        let actual_output = response
+            .choices
+            .into_iter()
+            .filter_map(|choice| match choice.message {
+                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
+                    open_ai::MessageContent::Plain(text) => text,
+                    open_ai::MessageContent::Multipart(parts) => parts
+                        .into_iter()
+                        .filter_map(|p| match p {
+                            open_ai::MessagePart::Text { text } => Some(text),
+                            _ => None,
+                        })
+                        .collect::<Vec<_>>()
+                        .join(""),
+                }),
+                _ => None,
+            })
+            .collect::<Vec<_>>()
+            .join("\n");
 
-    let actual_output = response
-        .choices
-        .into_iter()
-        .filter_map(|choice| match choice.message {
-            open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
-                open_ai::MessageContent::Plain(text) => text,
-                open_ai::MessageContent::Multipart(parts) => parts
-                    .into_iter()
-                    .filter_map(|p| match p {
-                        open_ai::MessagePart::Text { text } => Some(text),
-                        _ => None,
-                    })
-                    .collect::<Vec<_>>()
-                    .join(""),
-            }),
-            _ => None,
-        })
-        .collect::<Vec<_>>()
-        .join("\n");
-
-    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
-
-    let prediction = ExamplePrediction {
-        actual_patch: Some(actual_patch),
-        actual_output,
-        error: None,
-        provider: if batched {
-            PredictionProvider::Teacher(backend)
-        } else {
-            PredictionProvider::TeacherNonBatching(backend)
-        },
-    };
+        let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+
+        let prediction = ExamplePrediction {
+            actual_patch: Some(actual_patch),
+            actual_output,
+            error: None,
+            provider: if batched {
+                PredictionProvider::Teacher(backend)
+            } else {
+                PredictionProvider::TeacherNonBatching(backend)
+            },
+        };
 
-    example.predictions.push(prediction);
+        example.predictions.push(prediction);
+    }
 
     Ok(())
 }
diff --git a/crates/edit_prediction_cli/src/qa.rs b/crates/edit_prediction_cli/src/qa.rs
index d8ed61e9e1f4c8823a188e8917e74cc04042fec3..28a592c2b875303d59087e3fe5e0e7d176ee74c2 100644
--- a/crates/edit_prediction_cli/src/qa.rs
+++ b/crates/edit_prediction_cli/src/qa.rs
@@ -172,7 +172,7 @@ impl QaClient {
                 cache_control: None,
             }],
         }];
-        let response = client.generate(model, max_tokens, messages).await?;
+        let response = client.generate(model, max_tokens, messages, None).await?;
         Ok(response.map(|r| {
             r.content
                 .iter()
@@ -188,7 +188,7 @@ impl QaClient {
         let messages = vec![open_ai::RequestMessage::User {
             content: open_ai::MessageContent::Plain(prompt.to_string()),
         }];
-        let response = client.generate(model, max_tokens, messages).await?;
+        let response = client.generate(model, max_tokens, messages, None).await?;
         Ok(response.map(|r| {
             r.choices
                 .into_iter()
diff --git a/crates/edit_prediction_cli/src/repair.rs b/crates/edit_prediction_cli/src/repair.rs
index 2f8de97d8cdf0126314a03ad47c52f2815f41639..e78420d2d5f4bface31d1bd85e21165b38226f76 100644
--- a/crates/edit_prediction_cli/src/repair.rs
+++ b/crates/edit_prediction_cli/src/repair.rs
@@ -152,7 +152,7 @@ impl RepairClient {
                 cache_control: None,
             }],
         }];
-        let response = client.generate(model, max_tokens, messages).await?;
+        let response = client.generate(model, max_tokens, messages, None).await?;
         Ok(response.map(|r| {
             r.content
                 .iter()
@@ -168,7 +168,7 @@ impl RepairClient {
         let messages = vec![open_ai::RequestMessage::User {
             content: open_ai::MessageContent::Plain(prompt.to_string()),
         }];
-        let response = client.generate(model, max_tokens, messages).await?;
+        let response = client.generate(model, max_tokens, messages, None).await?;
         Ok(response.map(|r| {
             r.choices
                 .into_iter()
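
The sketch below is a standalone illustration, not part of the patch; the function and the message type are simplified stand-ins for the real client code. It shows the idea the new seed parameter relies on: folding an optional repetition index into the request hash gives each repetition of an otherwise identical request its own cache and batch key, while unseeded requests hash exactly as they did before.

// Hypothetical, self-contained sketch (not the crate's real API).
use std::hash::{Hash, Hasher};

fn request_hash(model: &str, max_tokens: u64, messages: &[String], seed: Option<usize>) -> String {
    let mut hasher = std::hash::DefaultHasher::new();
    model.hash(&mut hasher);
    max_tokens.hash(&mut hasher);
    for msg in messages {
        msg.hash(&mut hasher);
    }
    // Only hash the seed when present, so unseeded requests keep their old keys.
    if let Some(seed) = seed {
        seed.hash(&mut hasher);
    }
    format!("{:016x}", hasher.finish())
}

fn main() {
    let messages = vec!["prompt".to_string()];
    // Repetitions of the same request get distinct cache keys...
    assert_ne!(
        request_hash("model", 16384, &messages, Some(0)),
        request_hash("model", 16384, &messages, Some(1))
    );
    // ...while identical unseeded requests still collapse onto one key.
    assert_eq!(
        request_hash("model", 16384, &messages, None),
        request_hash("model", 16384, &messages, None)
    );
}

In predict.rs the seed is only Some(ix) when repetition_count is greater than one, so single-shot teacher runs keep hitting the cache and batch entries they produced before this change.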