From 0c47984a1940ff7dab1183651788fff0e6b7eb95 Mon Sep 17 00:00:00 2001 From: Michael Benfield Date: Sun, 14 Dec 2025 22:55:41 -0800 Subject: [PATCH] New evals for inline assistant (#44431) Also factor out some common code in the evals. Release Notes: - N/A --------- Co-authored-by: Mikayla Maki --- crates/agent/src/edit_agent/evals.rs | 1 + crates/agent_ui/Cargo.toml | 5 +- crates/agent_ui/src/agent_ui.rs | 2 - crates/agent_ui/src/buffer_codegen.rs | 184 +++++++----- crates/agent_ui/src/evals.rs | 89 ------ crates/agent_ui/src/inline_assistant.rs | 283 +++++++++++++++--- crates/agent_ui/src/inline_prompt_editor.rs | 12 +- crates/eval_utils/src/eval_utils.rs | 18 ++ crates/feature_flags/src/flags.rs | 2 +- .../language_models/src/provider/mistral.rs | 21 +- 10 files changed, 394 insertions(+), 223 deletions(-) delete mode 100644 crates/agent_ui/src/evals.rs diff --git a/crates/agent/src/edit_agent/evals.rs b/crates/agent/src/edit_agent/evals.rs index edf8a0f671d231b3bfbd29526c256388fd41f85a..01c81e0103a2d3624c7e8eb9b9c587726fcc4876 100644 --- a/crates/agent/src/edit_agent/evals.rs +++ b/crates/agent/src/edit_agent/evals.rs @@ -1343,6 +1343,7 @@ fn run_eval(eval: EvalInput) -> eval_utils::EvalOutput { let test = EditAgentTest::new(&mut cx).await; test.eval(eval, &mut cx).await }); + cx.quit(); match result { Ok(output) => eval_utils::EvalOutput { data: output.to_string(), diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index b235799635ce81b02fd6fcd5d4d7a53a6957eb77..38580b4d2c61597718d9fb718a20e52e84222481 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -13,7 +13,7 @@ path = "src/agent_ui.rs" doctest = false [features] -test-support = ["gpui/test-support", "language/test-support", "reqwest_client"] +test-support = ["assistant_text_thread/test-support", "eval_utils", "gpui/test-support", "language/test-support", "reqwest_client", "workspace/test-support"] unit-eval = [] [dependencies] @@ -40,6 +40,7 @@ component.workspace = true context_server.workspace = true db.workspace = true editor.workspace = true +eval_utils = { workspace = true, optional = true } extension.workspace = true extension_host.workspace = true feature_flags.workspace = true @@ -71,6 +72,7 @@ postage.workspace = true project.workspace = true prompt_store.workspace = true proto.workspace = true +rand.workspace = true release_channel.workspace = true rope.workspace = true rules_library.workspace = true @@ -119,7 +121,6 @@ language_model = { workspace = true, "features" = ["test-support"] } pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } semver.workspace = true -rand.workspace = true reqwest_client.workspace = true tree-sitter-md.workspace = true unindent.workspace = true diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index cd6113bfa6c611c8d2a6b9d43294e77737b7a9ae..91fccc5fca0221cc72b0972801bf4da382cedee8 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -7,8 +7,6 @@ mod buffer_codegen; mod completion_provider; mod context; mod context_server_configuration; -#[cfg(test)] -mod evals; mod inline_assistant; mod inline_prompt_editor; mod language_model_selector; diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index bb05d5e04deb06f82dfc8e5dae0d871648f1d11e..235aea092686e669c029e8c9c7741500c23d14cb 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -41,7 +41,6 @@ use std::{ time::Instant, }; use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff}; -use ui::SharedString; /// Use this tool to provide a message to the user when you're unable to complete a task. #[derive(Debug, Serialize, Deserialize, JsonSchema)] @@ -56,16 +55,16 @@ pub struct FailureMessageInput { /// Replaces text in tags with your replacement_text. #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct RewriteSectionInput { + /// The text to replace the section with. + #[serde(default)] + pub replacement_text: String, + /// A brief description of the edit you have made. /// /// The description may use markdown formatting if you wish. /// This is optional - if the edit is simple or obvious, you should leave it empty. #[serde(default)] pub description: String, - - /// The text to replace the section with. - #[serde(default)] - pub replacement_text: String, } pub struct BufferCodegen { @@ -287,8 +286,9 @@ pub struct CodegenAlternative { completion: Option, selected_text: Option, pub message_id: Option, - pub model_explanation: Option, session_id: Uuid, + pub description: Option, + pub failure: Option, } impl EventEmitter for CodegenAlternative {} @@ -346,8 +346,9 @@ impl CodegenAlternative { elapsed_time: None, completion: None, selected_text: None, - model_explanation: None, session_id, + description: None, + failure: None, _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), } } @@ -920,6 +921,16 @@ impl CodegenAlternative { self.completion.clone() } + #[cfg(any(test, feature = "test-support"))] + pub fn current_description(&self) -> Option { + self.description.clone() + } + + #[cfg(any(test, feature = "test-support"))] + pub fn current_failure(&self) -> Option { + self.failure.clone() + } + pub fn selected_text(&self) -> Option<&str> { self.selected_text.as_deref() } @@ -1133,32 +1144,69 @@ impl CodegenAlternative { } }; + enum ToolUseOutput { + Rewrite { + text: String, + description: Option, + }, + Failure(String), + } + + enum ModelUpdate { + Description(String), + Failure(String), + } + let chars_read_so_far = Arc::new(Mutex::new(0usize)); - let tool_to_text_and_message = - move |tool_use: LanguageModelToolUse| -> (Option, Option) { - let mut chars_read_so_far = chars_read_so_far.lock(); - match tool_use.name.as_ref() { - "rewrite_section" => { - let Ok(mut input) = - serde_json::from_value::(tool_use.input) - else { - return (None, None); - }; - let value = input.replacement_text[*chars_read_so_far..].to_string(); - *chars_read_so_far = input.replacement_text.len(); - (Some(value), Some(std::mem::take(&mut input.description))) - } - "failure_message" => { - let Ok(mut input) = - serde_json::from_value::(tool_use.input) - else { - return (None, None); - }; - (None, Some(std::mem::take(&mut input.message))) + let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option { + let mut chars_read_so_far = chars_read_so_far.lock(); + let is_complete = tool_use.is_input_complete; + match tool_use.name.as_ref() { + "rewrite_section" => { + let Ok(mut input) = + serde_json::from_value::(tool_use.input) + else { + return None; + }; + let text = input.replacement_text[*chars_read_so_far..].to_string(); + *chars_read_so_far = input.replacement_text.len(); + let description = is_complete + .then(|| { + let desc = std::mem::take(&mut input.description); + if desc.is_empty() { None } else { Some(desc) } + }) + .flatten(); + Some(ToolUseOutput::Rewrite { text, description }) + } + "failure_message" => { + if !is_complete { + return None; } - _ => (None, None), + let Ok(mut input) = + serde_json::from_value::(tool_use.input) + else { + return None; + }; + Some(ToolUseOutput::Failure(std::mem::take(&mut input.message))) } - }; + _ => None, + } + }; + + let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded::(); + + cx.spawn({ + let codegen = codegen.clone(); + async move |cx| { + while let Some(update) = message_rx.next().await { + let _ = codegen.update(cx, |this, _cx| match update { + ModelUpdate::Description(d) => this.description = Some(d), + ModelUpdate::Failure(f) => this.failure = Some(f), + }); + } + } + }) + .detach(); let mut message_id = None; let mut first_text = None; @@ -1171,24 +1219,23 @@ impl CodegenAlternative { Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => { message_id = Some(id); } - Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) - if matches!( - tool_use.name.as_ref(), - "rewrite_section" | "failure_message" - ) => - { - let is_complete = tool_use.is_input_complete; - let (text, message) = tool_to_text_and_message(tool_use); - // Only update the model explanation if the tool use is complete. - // Otherwise the UI element bounces around as it's updated. - if is_complete { - let _ = codegen.update(cx, |this, _cx| { - this.model_explanation = message.map(Into::into); - }); - } - first_text = text; - if first_text.is_some() { - break; + Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => { + if let Some(output) = process_tool_use(tool_use) { + let (text, update) = match output { + ToolUseOutput::Rewrite { text, description } => { + (Some(text), description.map(ModelUpdate::Description)) + } + ToolUseOutput::Failure(message) => { + (None, Some(ModelUpdate::Failure(message))) + } + }; + if let Some(update) = update { + let _ = message_tx.unbounded_send(update); + } + first_text = text; + if first_text.is_some() { + break; + } } } Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => { @@ -1215,41 +1262,30 @@ impl CodegenAlternative { return; }; - let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded(); - - cx.spawn({ - let codegen = codegen.clone(); - async move |cx| { - while let Some(message) = message_rx.next().await { - let _ = codegen.update(cx, |this, _cx| { - this.model_explanation = message; - }); - } - } - }) - .detach(); - let move_last_token_usage = last_token_usage.clone(); let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain( completion_events.filter_map(move |e| { - let tool_to_text_and_message = tool_to_text_and_message.clone(); + let process_tool_use = process_tool_use.clone(); let last_token_usage = move_last_token_usage.clone(); let total_text = total_text.clone(); let mut message_tx = message_tx.clone(); async move { match e { - Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) - if matches!( - tool_use.name.as_ref(), - "rewrite_section" | "failure_message" - ) => - { - let is_complete = tool_use.is_input_complete; - let (text, message) = tool_to_text_and_message(tool_use); - if is_complete { - // Again only send the message when complete to not get a bouncing UI element. - let _ = message_tx.send(message.map(Into::into)).await; + Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => { + let Some(output) = process_tool_use(tool_use) else { + return None; + }; + let (text, update) = match output { + ToolUseOutput::Rewrite { text, description } => { + (Some(text), description.map(ModelUpdate::Description)) + } + ToolUseOutput::Failure(message) => { + (None, Some(ModelUpdate::Failure(message))) + } + }; + if let Some(update) = update { + let _ = message_tx.send(update).await; } text.map(Ok) } diff --git a/crates/agent_ui/src/evals.rs b/crates/agent_ui/src/evals.rs deleted file mode 100644 index e82d21bd1fdb02a666c61bdf4754f27e79f92fda..0000000000000000000000000000000000000000 --- a/crates/agent_ui/src/evals.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::str::FromStr; - -use crate::inline_assistant::test::run_inline_assistant_test; - -use eval_utils::{EvalOutput, NoProcessor}; -use gpui::TestAppContext; -use language_model::{LanguageModelRegistry, SelectedModel}; -use rand::{SeedableRng as _, rngs::StdRng}; - -#[test] -#[cfg_attr(not(feature = "unit-eval"), ignore)] -fn eval_single_cursor_edit() { - eval_utils::eval(20, 1.0, NoProcessor, move || { - run_eval( - &EvalInput { - prompt: "Rename this variable to buffer_text".to_string(), - buffer: indoc::indoc! {" - struct EvalExampleStruct { - text: Strˇing, - prompt: String, - } - "} - .to_string(), - }, - &|_, output| { - let expected = indoc::indoc! {" - struct EvalExampleStruct { - buffer_text: String, - prompt: String, - } - "}; - if output == expected { - EvalOutput { - outcome: eval_utils::OutcomeKind::Passed, - data: "Passed!".to_string(), - metadata: (), - } - } else { - EvalOutput { - outcome: eval_utils::OutcomeKind::Failed, - data: format!("Failed to rename variable, output: {}", output), - metadata: (), - } - } - }, - ) - }); -} - -struct EvalInput { - buffer: String, - prompt: String, -} - -fn run_eval( - input: &EvalInput, - judge: &dyn Fn(&EvalInput, &str) -> eval_utils::EvalOutput<()>, -) -> eval_utils::EvalOutput<()> { - let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng()); - let mut cx = TestAppContext::build(dispatcher, None); - cx.skip_drawing(); - - let buffer_text = run_inline_assistant_test( - input.buffer.clone(), - input.prompt.clone(), - |cx| { - // Reconfigure to use a real model instead of the fake one - let model_name = std::env::var("ZED_AGENT_MODEL") - .unwrap_or("anthropic/claude-sonnet-4-latest".into()); - - let selected_model = SelectedModel::from_str(&model_name) - .expect("Invalid model format. Use 'provider/model-id'"); - - log::info!("Selected model: {selected_model:?}"); - - cx.update(|_, cx| { - LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry.select_inline_assistant_model(Some(&selected_model), cx); - }); - }); - }, - |_cx| { - log::info!("Waiting for actual response from the LLM..."); - }, - &mut cx, - ); - - judge(input, &buffer_text) -} diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 0eb96b3712623cc08632ede6c7836ed09499c02d..d036032e77d74dd905001affd9aba0010bc4f8eb 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -117,14 +117,6 @@ impl InlineAssistant { } } - #[cfg(any(test, feature = "test-support"))] - pub fn set_completion_receiver( - &mut self, - sender: mpsc::UnboundedSender>, - ) { - self._inline_assistant_completions = Some(sender); - } - pub fn register_workspace( &mut self, workspace: &Entity, @@ -1593,6 +1585,27 @@ impl InlineAssistant { .map(InlineAssistTarget::Terminal) } } + + #[cfg(any(test, feature = "test-support"))] + pub fn set_completion_receiver( + &mut self, + sender: mpsc::UnboundedSender>, + ) { + self._inline_assistant_completions = Some(sender); + } + + #[cfg(any(test, feature = "test-support"))] + pub fn get_codegen( + &mut self, + assist_id: InlineAssistId, + cx: &mut App, + ) -> Option> { + self.assists.get(&assist_id).map(|inline_assist| { + inline_assist + .codegen + .update(cx, |codegen, _cx| codegen.active_alternative().clone()) + }) + } } struct EditorInlineAssists { @@ -2014,8 +2027,10 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { } } -#[cfg(any(test, feature = "test-support"))] +#[cfg(any(test, feature = "unit-eval"))] +#[cfg_attr(not(test), allow(dead_code))] pub mod test { + use std::sync::Arc; use agent::HistoryStore; @@ -2026,7 +2041,6 @@ pub mod test { use futures::channel::mpsc; use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; use language::Buffer; - use language_model::LanguageModelRegistry; use project::Project; use prompt_store::PromptBuilder; use smol::stream::StreamExt as _; @@ -2035,13 +2049,43 @@ pub mod test { use crate::InlineAssistant; + #[derive(Debug)] + pub enum InlineAssistantOutput { + Success { + completion: Option, + description: Option, + full_buffer_text: String, + }, + Failure { + failure: String, + }, + // These fields are used for logging + #[allow(unused)] + Malformed { + completion: Option, + description: Option, + failure: Option, + }, + } + + impl InlineAssistantOutput { + pub fn buffer_text(&self) -> &str { + match self { + InlineAssistantOutput::Success { + full_buffer_text, .. + } => full_buffer_text, + _ => "", + } + } + } + pub fn run_inline_assistant_test( base_buffer: String, prompt: String, setup: SetupF, test: TestF, cx: &mut TestAppContext, - ) -> String + ) -> InlineAssistantOutput where SetupF: FnOnce(&mut gpui::VisualTestContext), TestF: FnOnce(&mut gpui::VisualTestContext), @@ -2133,39 +2177,198 @@ pub mod test { test(cx); - cx.executor() - .block_test(async { completion_rx.next().await }); + let assist_id = cx + .executor() + .block_test(async { completion_rx.next().await }) + .unwrap() + .unwrap(); + + let (completion, description, failure) = cx.update(|_, cx| { + InlineAssistant::update_global(cx, |inline_assistant, cx| { + let codegen = inline_assistant.get_codegen(assist_id, cx).unwrap(); + + let completion = codegen.read(cx).current_completion(); + let description = codegen.read(cx).current_description(); + let failure = codegen.read(cx).current_failure(); - buffer.read_with(cx, |buffer, _| buffer.text()) + (completion, description, failure) + }) + }); + + if failure.is_some() && (completion.is_some() || description.is_some()) { + InlineAssistantOutput::Malformed { + completion, + description, + failure, + } + } else if let Some(failure) = failure { + InlineAssistantOutput::Failure { failure } + } else { + InlineAssistantOutput::Success { + completion, + description, + full_buffer_text: buffer.read_with(cx, |buffer, _| buffer.text()), + } + } } +} - #[allow(unused)] - pub fn test_inline_assistant( - base_buffer: &'static str, - llm_output: &'static str, - cx: &mut TestAppContext, - ) -> String { - run_inline_assistant_test( - base_buffer.to_string(), - "Prompt doesn't matter because we're using a fake model".to_string(), - |cx| { - cx.update(|_, cx| LanguageModelRegistry::test(cx)); - }, - |cx| { - let fake_model = cx.update(|_, cx| { - LanguageModelRegistry::global(cx) - .update(cx, |registry, _| registry.fake_model()) - }); - let fake = fake_model.as_fake(); +#[cfg(any(test, feature = "unit-eval"))] +#[cfg_attr(not(test), allow(dead_code))] +pub mod evals { + use std::str::FromStr; + + use eval_utils::{EvalOutput, NoProcessor}; + use gpui::TestAppContext; + use language_model::{LanguageModelRegistry, SelectedModel}; + use rand::{SeedableRng as _, rngs::StdRng}; + + use crate::inline_assistant::test::{InlineAssistantOutput, run_inline_assistant_test}; + + #[test] + #[cfg_attr(not(feature = "unit-eval"), ignore)] + fn eval_single_cursor_edit() { + run_eval( + 20, + 1.0, + "Rename this variable to buffer_text".to_string(), + indoc::indoc! {" + struct EvalExampleStruct { + text: Strˇing, + prompt: String, + } + "} + .to_string(), + exact_buffer_match(indoc::indoc! {" + struct EvalExampleStruct { + buffer_text: String, + prompt: String, + } + "}), + ); + } - // let fake = fake_model; - fake.send_last_completion_stream_text_chunk(llm_output.to_string()); - fake.end_last_completion_stream(); + #[test] + #[cfg_attr(not(feature = "unit-eval"), ignore)] + fn eval_cant_do() { + run_eval( + 20, + 1.0, + "Rename the struct to EvalExampleStructNope", + indoc::indoc! {" + struct EvalExampleStruct { + text: Strˇing, + prompt: String, + } + "}, + uncertain_output, + ); + } - // Run again to process the model's response - cx.run_until_parked(); - }, - cx, - ) + #[test] + #[cfg_attr(not(feature = "unit-eval"), ignore)] + fn eval_unclear() { + run_eval( + 20, + 1.0, + "Make exactly the change I want you to make", + indoc::indoc! {" + struct EvalExampleStruct { + text: Strˇing, + prompt: String, + } + "}, + uncertain_output, + ); + } + + fn run_eval( + iterations: usize, + expected_pass_ratio: f32, + prompt: impl Into, + buffer: impl Into, + judge: impl Fn(InlineAssistantOutput) -> eval_utils::EvalOutput<()> + Send + Sync + 'static, + ) { + let buffer = buffer.into(); + let prompt = prompt.into(); + + eval_utils::eval(iterations, expected_pass_ratio, NoProcessor, move || { + let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng()); + let mut cx = TestAppContext::build(dispatcher, None); + cx.skip_drawing(); + + let output = run_inline_assistant_test( + buffer.clone(), + prompt.clone(), + |cx| { + // Reconfigure to use a real model instead of the fake one + let model_name = std::env::var("ZED_AGENT_MODEL") + .unwrap_or("anthropic/claude-sonnet-4-latest".into()); + + let selected_model = SelectedModel::from_str(&model_name) + .expect("Invalid model format. Use 'provider/model-id'"); + + log::info!("Selected model: {selected_model:?}"); + + cx.update(|_, cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.select_inline_assistant_model(Some(&selected_model), cx); + }); + }); + }, + |_cx| { + log::info!("Waiting for actual response from the LLM..."); + }, + &mut cx, + ); + + cx.quit(); + + judge(output) + }); + } + + fn uncertain_output(output: InlineAssistantOutput) -> EvalOutput<()> { + match &output { + o @ InlineAssistantOutput::Success { + completion, + description, + .. + } => { + if description.is_some() && completion.is_none() { + EvalOutput::passed(format!( + "Assistant produced no completion, but a description:\n{}", + description.as_ref().unwrap() + )) + } else { + EvalOutput::failed(format!("Assistant produced a completion:\n{:?}", o)) + } + } + InlineAssistantOutput::Failure { + failure: error_message, + } => EvalOutput::passed(format!( + "Assistant produced a failure message: {}", + error_message + )), + o @ InlineAssistantOutput::Malformed { .. } => { + EvalOutput::failed(format!("Assistant produced a malformed response:\n{:?}", o)) + } + } + } + + fn exact_buffer_match( + correct_output: impl Into, + ) -> impl Fn(InlineAssistantOutput) -> EvalOutput<()> { + let correct_output = correct_output.into(); + move |output| { + if output.buffer_text() == correct_output { + EvalOutput::passed("Assistant output matches") + } else { + EvalOutput::failed(format!( + "Assistant output does not match expected output: {:?}", + output + )) + } + } } } diff --git a/crates/agent_ui/src/inline_prompt_editor.rs b/crates/agent_ui/src/inline_prompt_editor.rs index e262cda87899b0314c9fd8909f5718b4fd7dbfda..278216e28ec6304a9fc596c8456921fb1f1ebdfd 100644 --- a/crates/agent_ui/src/inline_prompt_editor.rs +++ b/crates/agent_ui/src/inline_prompt_editor.rs @@ -101,11 +101,11 @@ impl Render for PromptEditor { let left_gutter_width = gutter.full_width() + (gutter.margin / 2.0); let right_padding = editor_margins.right + RIGHT_PADDING; - let explanation = codegen - .active_alternative() - .read(cx) - .model_explanation - .clone(); + let active_alternative = codegen.active_alternative().read(cx); + let explanation = active_alternative + .description + .clone() + .or_else(|| active_alternative.failure.clone()); (left_gutter_width, right_padding, explanation) } @@ -139,7 +139,7 @@ impl Render for PromptEditor { if let Some(explanation) = &explanation { markdown.update(cx, |markdown, cx| { - markdown.reset(explanation.clone(), cx); + markdown.reset(SharedString::from(explanation), cx); }); } diff --git a/crates/eval_utils/src/eval_utils.rs b/crates/eval_utils/src/eval_utils.rs index 880b1a97e414bbc3219bdf8f7163dbf9b6c9c82b..be3294ed1490d6a602c3a5282d25dbba7d065443 100644 --- a/crates/eval_utils/src/eval_utils.rs +++ b/crates/eval_utils/src/eval_utils.rs @@ -40,6 +40,24 @@ pub struct EvalOutput { pub metadata: M, } +impl EvalOutput { + pub fn passed(message: impl Into) -> Self { + EvalOutput { + outcome: OutcomeKind::Passed, + data: message.into(), + metadata: M::default(), + } + } + + pub fn failed(message: impl Into) -> Self { + EvalOutput { + outcome: OutcomeKind::Failed, + data: message.into(), + metadata: M::default(), + } + } +} + pub struct NoProcessor; impl EvalOutputProcessor for NoProcessor { type Metadata = (); diff --git a/crates/feature_flags/src/flags.rs b/crates/feature_flags/src/flags.rs index 566d5604149567702e8739d2f3ac9fdc6f5f0de8..0d474878f999bc773baff7664ca0305c2031c171 100644 --- a/crates/feature_flags/src/flags.rs +++ b/crates/feature_flags/src/flags.rs @@ -18,6 +18,6 @@ impl FeatureFlag for InlineAssistantUseToolFeatureFlag { const NAME: &'static str = "inline-assistant-use-tool"; fn enabled_for_staff() -> bool { - false + true } } diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 1078e2d7f7841d7ad05284e10a9f862236966ebc..3e99f32be8224bb2b9973feccb0ce973b58eaaed 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -17,7 +17,7 @@ use settings::{Settings, SettingsStore}; use std::collections::HashMap; use std::pin::Pin; use std::str::FromStr; -use std::sync::{Arc, LazyLock, OnceLock}; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; @@ -31,7 +31,6 @@ static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY"; static CODESTRAL_API_KEY_ENV_VAR: LazyLock = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME); -static CODESTRAL_API_KEY: OnceLock> = OnceLock::new(); #[derive(Default, Clone, Debug, PartialEq)] pub struct MistralSettings { @@ -49,14 +48,18 @@ pub struct State { codestral_api_key_state: Entity, } +struct CodestralApiKey(Entity); +impl Global for CodestralApiKey {} + pub fn codestral_api_key(cx: &mut App) -> Entity { - return CODESTRAL_API_KEY - .get_or_init(|| { - cx.new(|_| { - ApiKeyState::new(CODESTRAL_API_URL.into(), CODESTRAL_API_KEY_ENV_VAR.clone()) - }) - }) - .clone(); + if cx.has_global::() { + cx.global::().0.clone() + } else { + let api_key_state = cx + .new(|_| ApiKeyState::new(CODESTRAL_API_URL.into(), CODESTRAL_API_KEY_ENV_VAR.clone())); + cx.set_global(CodestralApiKey(api_key_state.clone())); + api_key_state + } } impl State {