Make --repetions flag work for teacher

Max Brunsfeld created

Change summary

crates/edit_prediction_cli/src/anthropic_client.rs |  31 ++
crates/edit_prediction_cli/src/openai_client.rs    |  26 +
crates/edit_prediction_cli/src/predict.rs          | 179 ++++++++-------
crates/edit_prediction_cli/src/qa.rs               |   4 
crates/edit_prediction_cli/src/repair.rs           |   4 
5 files changed, 143 insertions(+), 101 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/anthropic_client.rs 🔗

@@ -208,8 +208,9 @@ impl BatchingLlmClient {
         model: &str,
         max_tokens: u64,
         messages: &[Message],
+        seed: Option<usize>,
     ) -> Result<Option<AnthropicResponse>> {
-        let request_hash_str = Self::request_hash(model, max_tokens, messages);
+        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
         let connection = self.connection.lock().unwrap();
         let response: Vec<String> = connection.select_bound(
             &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
@@ -220,8 +221,14 @@ impl BatchingLlmClient {
             .and_then(|text| serde_json::from_str(&text).ok()))
     }
 
-    pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
-        let request_hash = Self::request_hash(model, max_tokens, messages);
+    pub fn mark_for_batch(
+        &self,
+        model: &str,
+        max_tokens: u64,
+        messages: &[Message],
+        seed: Option<usize>,
+    ) -> Result<()> {
+        let request_hash = Self::request_hash(model, max_tokens, messages, seed);
 
         let serializable_messages: Vec<SerializableMessage> = messages
             .iter()
@@ -259,13 +266,14 @@ impl BatchingLlmClient {
         model: &str,
         max_tokens: u64,
         messages: Vec<Message>,
+        seed: Option<usize>,
     ) -> Result<Option<AnthropicResponse>> {
-        let response = self.lookup(model, max_tokens, &messages)?;
+        let response = self.lookup(model, max_tokens, &messages, seed)?;
         if let Some(response) = response {
             return Ok(Some(response));
         }
 
-        self.mark_for_batch(model, max_tokens, &messages)?;
+        self.mark_for_batch(model, max_tokens, &messages, seed)?;
 
         Ok(None)
     }
@@ -606,13 +614,21 @@ impl BatchingLlmClient {
         Ok(all_batch_ids)
     }
 
-    fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
+    fn request_hash(
+        model: &str,
+        max_tokens: u64,
+        messages: &[Message],
+        seed: Option<usize>,
+    ) -> String {
         let mut hasher = std::hash::DefaultHasher::new();
         model.hash(&mut hasher);
         max_tokens.hash(&mut hasher);
         for msg in messages {
             message_content_to_string(&msg.content).hash(&mut hasher);
         }
+        if let Some(seed) = seed {
+            seed.hash(&mut hasher);
+        }
         let request_hash = hasher.finish();
         format!("{request_hash:016x}")
     }
@@ -655,6 +671,7 @@ impl AnthropicClient {
         model: &str,
         max_tokens: u64,
         messages: Vec<Message>,
+        seed: Option<usize>,
     ) -> Result<Option<AnthropicResponse>> {
         match self {
             AnthropicClient::Plain(plain_llm_client) => plain_llm_client
@@ -663,7 +680,7 @@ impl AnthropicClient {
                 .map(Some),
             AnthropicClient::Batch(batching_llm_client) => {
                 batching_llm_client
-                    .generate(model, max_tokens, messages)
+                    .generate(model, max_tokens, messages, seed)
                     .await
             }
             AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),

crates/edit_prediction_cli/src/openai_client.rs 🔗

@@ -138,8 +138,9 @@ impl BatchingOpenAiClient {
         model: &str,
         max_tokens: u64,
         messages: &[RequestMessage],
+        seed: Option<usize>,
     ) -> Result<Option<OpenAiResponse>> {
-        let request_hash_str = Self::request_hash(model, max_tokens, messages);
+        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
         let connection = self.connection.lock().unwrap();
         let response: Vec<String> = connection.select_bound(
             &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
@@ -155,8 +156,9 @@ impl BatchingOpenAiClient {
         model: &str,
         max_tokens: u64,
         messages: &[RequestMessage],
+        seed: Option<usize>,
     ) -> Result<()> {
-        let request_hash = Self::request_hash(model, max_tokens, messages);
+        let request_hash = Self::request_hash(model, max_tokens, messages, seed);
 
         let serializable_messages: Vec<SerializableMessage> = messages
             .iter()
@@ -191,13 +193,14 @@ impl BatchingOpenAiClient {
         model: &str,
         max_tokens: u64,
         messages: Vec<RequestMessage>,
+        seed: Option<usize>,
     ) -> Result<Option<OpenAiResponse>> {
-        let response = self.lookup(model, max_tokens, &messages)?;
+        let response = self.lookup(model, max_tokens, &messages, seed)?;
         if let Some(response) = response {
             return Ok(Some(response));
         }
 
-        self.mark_for_batch(model, max_tokens, &messages)?;
+        self.mark_for_batch(model, max_tokens, &messages, seed)?;
 
         Ok(None)
     }
@@ -558,7 +561,12 @@ impl BatchingOpenAiClient {
         Ok(all_batch_ids)
     }
 
-    fn request_hash(model: &str, max_tokens: u64, messages: &[RequestMessage]) -> String {
+    fn request_hash(
+        model: &str,
+        max_tokens: u64,
+        messages: &[RequestMessage],
+        seed: Option<usize>,
+    ) -> String {
         let mut hasher = std::hash::DefaultHasher::new();
         "openai".hash(&mut hasher);
         model.hash(&mut hasher);
@@ -566,6 +574,9 @@ impl BatchingOpenAiClient {
         for msg in messages {
             message_content_to_string(msg).hash(&mut hasher);
         }
+        if let Some(seed) = seed {
+            seed.hash(&mut hasher);
+        }
         let request_hash = hasher.finish();
         format!("{request_hash:016x}")
     }
@@ -631,6 +642,7 @@ impl OpenAiClient {
         model: &str,
         max_tokens: u64,
         messages: Vec<RequestMessage>,
+        seed: Option<usize>,
     ) -> Result<Option<OpenAiResponse>> {
         match self {
             OpenAiClient::Plain(plain_client) => plain_client
@@ -638,7 +650,9 @@ impl OpenAiClient {
                 .await
                 .map(Some),
             OpenAiClient::Batch(batching_client) => {
-                batching_client.generate(model, max_tokens, messages).await
+                batching_client
+                    .generate(model, max_tokens, messages, seed)
+                    .await
             }
             OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
         }

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -69,7 +69,7 @@ pub async fn run_prediction(
         .await?;
 
         let batched = matches!(provider, PredictionProvider::Teacher(..));
-        return predict_teacher(example, backend, batched).await;
+        return predict_teacher(example, backend, batched, repetition_count).await;
     }
 
     run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
@@ -261,10 +261,13 @@ async fn predict_teacher(
     example: &mut Example,
     backend: TeacherBackend,
     batched: bool,
+    repetition_count: usize,
 ) -> anyhow::Result<()> {
     match backend {
-        TeacherBackend::Sonnet45 => predict_anthropic(example, backend, batched).await,
-        TeacherBackend::Gpt52 => predict_openai(example, backend, batched).await,
+        TeacherBackend::Sonnet45 => {
+            predict_anthropic(example, backend, batched, repetition_count).await
+        }
+        TeacherBackend::Gpt52 => predict_openai(example, backend, batched, repetition_count).await,
     }
 }
 
@@ -272,6 +275,7 @@ async fn predict_anthropic(
     example: &mut Example,
     backend: TeacherBackend,
     batched: bool,
+    repetition_count: usize,
 ) -> anyhow::Result<()> {
     let llm_model_name = backend.model_name();
     let max_tokens = 16384;
@@ -286,46 +290,49 @@ async fn predict_anthropic(
 
     let prompt = example.prompt.as_ref().context("Prompt is required")?;
 
-    let messages = vec![anthropic::Message {
-        role: anthropic::Role::User,
-        content: vec![anthropic::RequestContent::Text {
-            text: prompt.input.clone(),
-            cache_control: None,
-        }],
-    }];
-
-    let Some(response) = llm_client
-        .generate(llm_model_name, max_tokens, messages)
-        .await?
-    else {
-        // Request stashed for batched processing
-        return Ok(());
-    };
+    for ix in 0..repetition_count {
+        let messages = vec![anthropic::Message {
+            role: anthropic::Role::User,
+            content: vec![anthropic::RequestContent::Text {
+                text: prompt.input.clone(),
+                cache_control: None,
+            }],
+        }];
+
+        let seed = if repetition_count > 1 { Some(ix) } else { None };
+        let Some(response) = llm_client
+            .generate(llm_model_name, max_tokens, messages, seed)
+            .await?
+        else {
+            // Request stashed for batched processing
+            return Ok(());
+        };
 
-    let actual_output = response
-        .content
-        .into_iter()
-        .filter_map(|content| match content {
-            anthropic::ResponseContent::Text { text } => Some(text),
-            _ => None,
-        })
-        .collect::<Vec<String>>()
-        .join("\n");
-
-    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
-
-    let prediction = ExamplePrediction {
-        actual_patch: Some(actual_patch),
-        actual_output,
-        error: None,
-        provider: if batched {
-            PredictionProvider::Teacher(backend)
-        } else {
-            PredictionProvider::TeacherNonBatching(backend)
-        },
-    };
+        let actual_output = response
+            .content
+            .into_iter()
+            .filter_map(|content| match content {
+                anthropic::ResponseContent::Text { text } => Some(text),
+                _ => None,
+            })
+            .collect::<Vec<String>>()
+            .join("\n");
+
+        let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+
+        let prediction = ExamplePrediction {
+            actual_patch: Some(actual_patch),
+            actual_output,
+            error: None,
+            provider: if batched {
+                PredictionProvider::Teacher(backend)
+            } else {
+                PredictionProvider::TeacherNonBatching(backend)
+            },
+        };
 
-    example.predictions.push(prediction);
+        example.predictions.push(prediction);
+    }
     Ok(())
 }
 
@@ -333,6 +340,7 @@ async fn predict_openai(
     example: &mut Example,
     backend: TeacherBackend,
     batched: bool,
+    repetition_count: usize,
 ) -> anyhow::Result<()> {
     let llm_model_name = backend.model_name();
     let max_tokens = 16384;
@@ -347,52 +355,55 @@ async fn predict_openai(
 
     let prompt = example.prompt.as_ref().context("Prompt is required")?;
 
-    let messages = vec![open_ai::RequestMessage::User {
-        content: open_ai::MessageContent::Plain(prompt.input.clone()),
-    }];
+    for ix in 0..repetition_count {
+        let messages = vec![open_ai::RequestMessage::User {
+            content: open_ai::MessageContent::Plain(prompt.input.clone()),
+        }];
+
+        let seed = if repetition_count > 1 { Some(ix) } else { None };
+        let Some(response) = llm_client
+            .generate(llm_model_name, max_tokens, messages, seed)
+            .await?
+        else {
+            // Request stashed for batched processing
+            return Ok(());
+        };
 
-    let Some(response) = llm_client
-        .generate(llm_model_name, max_tokens, messages)
-        .await?
-    else {
-        // Request stashed for batched processing
-        return Ok(());
-    };
+        let actual_output = response
+            .choices
+            .into_iter()
+            .filter_map(|choice| match choice.message {
+                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
+                    open_ai::MessageContent::Plain(text) => text,
+                    open_ai::MessageContent::Multipart(parts) => parts
+                        .into_iter()
+                        .filter_map(|p| match p {
+                            open_ai::MessagePart::Text { text } => Some(text),
+                            _ => None,
+                        })
+                        .collect::<Vec<_>>()
+                        .join(""),
+                }),
+                _ => None,
+            })
+            .collect::<Vec<String>>()
+            .join("\n");
 
-    let actual_output = response
-        .choices
-        .into_iter()
-        .filter_map(|choice| match choice.message {
-            open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
-                open_ai::MessageContent::Plain(text) => text,
-                open_ai::MessageContent::Multipart(parts) => parts
-                    .into_iter()
-                    .filter_map(|p| match p {
-                        open_ai::MessagePart::Text { text } => Some(text),
-                        _ => None,
-                    })
-                    .collect::<Vec<_>>()
-                    .join(""),
-            }),
-            _ => None,
-        })
-        .collect::<Vec<String>>()
-        .join("\n");
-
-    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
-
-    let prediction = ExamplePrediction {
-        actual_patch: Some(actual_patch),
-        actual_output,
-        error: None,
-        provider: if batched {
-            PredictionProvider::Teacher(backend)
-        } else {
-            PredictionProvider::TeacherNonBatching(backend)
-        },
-    };
+        let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+
+        let prediction = ExamplePrediction {
+            actual_patch: Some(actual_patch),
+            actual_output,
+            error: None,
+            provider: if batched {
+                PredictionProvider::Teacher(backend)
+            } else {
+                PredictionProvider::TeacherNonBatching(backend)
+            },
+        };
 
-    example.predictions.push(prediction);
+        example.predictions.push(prediction);
+    }
     Ok(())
 }
 

crates/edit_prediction_cli/src/qa.rs 🔗

@@ -172,7 +172,7 @@ impl QaClient {
                         cache_control: None,
                     }],
                 }];
-                let response = client.generate(model, max_tokens, messages).await?;
+                let response = client.generate(model, max_tokens, messages, None).await?;
                 Ok(response.map(|r| {
                     r.content
                         .iter()
@@ -188,7 +188,7 @@ impl QaClient {
                 let messages = vec![open_ai::RequestMessage::User {
                     content: open_ai::MessageContent::Plain(prompt.to_string()),
                 }];
-                let response = client.generate(model, max_tokens, messages).await?;
+                let response = client.generate(model, max_tokens, messages, None).await?;
                 Ok(response.map(|r| {
                     r.choices
                         .into_iter()

crates/edit_prediction_cli/src/repair.rs 🔗

@@ -152,7 +152,7 @@ impl RepairClient {
                         cache_control: None,
                     }],
                 }];
-                let response = client.generate(model, max_tokens, messages).await?;
+                let response = client.generate(model, max_tokens, messages, None).await?;
                 Ok(response.map(|r| {
                     r.content
                         .iter()
@@ -168,7 +168,7 @@ impl RepairClient {
                 let messages = vec![open_ai::RequestMessage::User {
                     content: open_ai::MessageContent::Plain(prompt.to_string()),
                 }];
-                let response = client.generate(model, max_tokens, messages).await?;
+                let response = client.generate(model, max_tokens, messages, None).await?;
                 Ok(response.map(|r| {
                     r.choices
                         .into_iter()