Detailed changes
@@ -1343,6 +1343,7 @@ fn run_eval(eval: EvalInput) -> eval_utils::EvalOutput<EditEvalMetadata> {
let test = EditAgentTest::new(&mut cx).await;
test.eval(eval, &mut cx).await
});
+ cx.quit();
match result {
Ok(output) => eval_utils::EvalOutput {
data: output.to_string(),
@@ -13,7 +13,7 @@ path = "src/agent_ui.rs"
doctest = false
[features]
-test-support = ["gpui/test-support", "language/test-support", "reqwest_client"]
+test-support = ["assistant_text_thread/test-support", "eval_utils", "gpui/test-support", "language/test-support", "reqwest_client", "workspace/test-support"]
unit-eval = []
[dependencies]
@@ -40,6 +40,7 @@ component.workspace = true
context_server.workspace = true
db.workspace = true
editor.workspace = true
+eval_utils = { workspace = true, optional = true }
extension.workspace = true
extension_host.workspace = true
feature_flags.workspace = true
@@ -71,6 +72,7 @@ postage.workspace = true
project.workspace = true
prompt_store.workspace = true
proto.workspace = true
+rand.workspace = true
release_channel.workspace = true
rope.workspace = true
rules_library.workspace = true
@@ -119,7 +121,6 @@ language_model = { workspace = true, "features" = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
semver.workspace = true
-rand.workspace = true
reqwest_client.workspace = true
tree-sitter-md.workspace = true
unindent.workspace = true
@@ -7,8 +7,6 @@ mod buffer_codegen;
mod completion_provider;
mod context;
mod context_server_configuration;
-#[cfg(test)]
-mod evals;
mod inline_assistant;
mod inline_prompt_editor;
mod language_model_selector;
@@ -41,7 +41,6 @@ use std::{
time::Instant,
};
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
-use ui::SharedString;
/// Use this tool to provide a message to the user when you're unable to complete a task.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@@ -56,16 +55,16 @@ pub struct FailureMessageInput {
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
+ /// The text to replace the section with.
+ #[serde(default)]
+ pub replacement_text: String,
+
/// A brief description of the edit you have made.
///
/// The description may use markdown formatting if you wish.
/// This is optional - if the edit is simple or obvious, you should leave it empty.
#[serde(default)]
pub description: String,
-
- /// The text to replace the section with.
- #[serde(default)]
- pub replacement_text: String,
}
pub struct BufferCodegen {
@@ -287,8 +286,9 @@ pub struct CodegenAlternative {
completion: Option<String>,
selected_text: Option<String>,
pub message_id: Option<String>,
- pub model_explanation: Option<SharedString>,
session_id: Uuid,
+ pub description: Option<String>,
+ pub failure: Option<String>,
}
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
@@ -346,8 +346,9 @@ impl CodegenAlternative {
elapsed_time: None,
completion: None,
selected_text: None,
- model_explanation: None,
session_id,
+ description: None,
+ failure: None,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
}
}
@@ -920,6 +921,16 @@ impl CodegenAlternative {
self.completion.clone()
}
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn current_description(&self) -> Option<String> {
+ self.description.clone()
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn current_failure(&self) -> Option<String> {
+ self.failure.clone()
+ }
+
pub fn selected_text(&self) -> Option<&str> {
self.selected_text.as_deref()
}
@@ -1133,32 +1144,69 @@ impl CodegenAlternative {
}
};
+ enum ToolUseOutput {
+ Rewrite {
+ text: String,
+ description: Option<String>,
+ },
+ Failure(String),
+ }
+
+ enum ModelUpdate {
+ Description(String),
+ Failure(String),
+ }
+
let chars_read_so_far = Arc::new(Mutex::new(0usize));
- let tool_to_text_and_message =
- move |tool_use: LanguageModelToolUse| -> (Option<String>, Option<String>) {
- let mut chars_read_so_far = chars_read_so_far.lock();
- match tool_use.name.as_ref() {
- "rewrite_section" => {
- let Ok(mut input) =
- serde_json::from_value::<RewriteSectionInput>(tool_use.input)
- else {
- return (None, None);
- };
- let value = input.replacement_text[*chars_read_so_far..].to_string();
- *chars_read_so_far = input.replacement_text.len();
- (Some(value), Some(std::mem::take(&mut input.description)))
- }
- "failure_message" => {
- let Ok(mut input) =
- serde_json::from_value::<FailureMessageInput>(tool_use.input)
- else {
- return (None, None);
- };
- (None, Some(std::mem::take(&mut input.message)))
+ let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option<ToolUseOutput> {
+ let mut chars_read_so_far = chars_read_so_far.lock();
+ let is_complete = tool_use.is_input_complete;
+ match tool_use.name.as_ref() {
+ "rewrite_section" => {
+ let Ok(mut input) =
+ serde_json::from_value::<RewriteSectionInput>(tool_use.input)
+ else {
+ return None;
+ };
+ let text = input.replacement_text[*chars_read_so_far..].to_string();
+ *chars_read_so_far = input.replacement_text.len();
+ let description = is_complete
+ .then(|| {
+ let desc = std::mem::take(&mut input.description);
+ if desc.is_empty() { None } else { Some(desc) }
+ })
+ .flatten();
+ Some(ToolUseOutput::Rewrite { text, description })
+ }
+ "failure_message" => {
+ if !is_complete {
+ return None;
}
- _ => (None, None),
+ let Ok(mut input) =
+ serde_json::from_value::<FailureMessageInput>(tool_use.input)
+ else {
+ return None;
+ };
+ Some(ToolUseOutput::Failure(std::mem::take(&mut input.message)))
}
- };
+ _ => None,
+ }
+ };
+
+ let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded::<ModelUpdate>();
+
+ cx.spawn({
+ let codegen = codegen.clone();
+ async move |cx| {
+ while let Some(update) = message_rx.next().await {
+ let _ = codegen.update(cx, |this, _cx| match update {
+ ModelUpdate::Description(d) => this.description = Some(d),
+ ModelUpdate::Failure(f) => this.failure = Some(f),
+ });
+ }
+ }
+ })
+ .detach();
let mut message_id = None;
let mut first_text = None;
@@ -1171,24 +1219,23 @@ impl CodegenAlternative {
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
message_id = Some(id);
}
- Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
- if matches!(
- tool_use.name.as_ref(),
- "rewrite_section" | "failure_message"
- ) =>
- {
- let is_complete = tool_use.is_input_complete;
- let (text, message) = tool_to_text_and_message(tool_use);
- // Only update the model explanation if the tool use is complete.
- // Otherwise the UI element bounces around as it's updated.
- if is_complete {
- let _ = codegen.update(cx, |this, _cx| {
- this.model_explanation = message.map(Into::into);
- });
- }
- first_text = text;
- if first_text.is_some() {
- break;
+ Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
+ if let Some(output) = process_tool_use(tool_use) {
+ let (text, update) = match output {
+ ToolUseOutput::Rewrite { text, description } => {
+ (Some(text), description.map(ModelUpdate::Description))
+ }
+ ToolUseOutput::Failure(message) => {
+ (None, Some(ModelUpdate::Failure(message)))
+ }
+ };
+ if let Some(update) = update {
+ let _ = message_tx.unbounded_send(update);
+ }
+ first_text = text;
+ if first_text.is_some() {
+ break;
+ }
}
}
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
@@ -1215,41 +1262,30 @@ impl CodegenAlternative {
return;
};
- let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded();
-
- cx.spawn({
- let codegen = codegen.clone();
- async move |cx| {
- while let Some(message) = message_rx.next().await {
- let _ = codegen.update(cx, |this, _cx| {
- this.model_explanation = message;
- });
- }
- }
- })
- .detach();
-
let move_last_token_usage = last_token_usage.clone();
let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
completion_events.filter_map(move |e| {
- let tool_to_text_and_message = tool_to_text_and_message.clone();
+ let process_tool_use = process_tool_use.clone();
let last_token_usage = move_last_token_usage.clone();
let total_text = total_text.clone();
let mut message_tx = message_tx.clone();
async move {
match e {
- Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
- if matches!(
- tool_use.name.as_ref(),
- "rewrite_section" | "failure_message"
- ) =>
- {
- let is_complete = tool_use.is_input_complete;
- let (text, message) = tool_to_text_and_message(tool_use);
- if is_complete {
- // Again only send the message when complete to not get a bouncing UI element.
- let _ = message_tx.send(message.map(Into::into)).await;
+ Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
+ let Some(output) = process_tool_use(tool_use) else {
+ return None;
+ };
+ let (text, update) = match output {
+ ToolUseOutput::Rewrite { text, description } => {
+ (Some(text), description.map(ModelUpdate::Description))
+ }
+ ToolUseOutput::Failure(message) => {
+ (None, Some(ModelUpdate::Failure(message)))
+ }
+ };
+ if let Some(update) = update {
+ let _ = message_tx.send(update).await;
}
text.map(Ok)
}
@@ -1,89 +0,0 @@
-use std::str::FromStr;
-
-use crate::inline_assistant::test::run_inline_assistant_test;
-
-use eval_utils::{EvalOutput, NoProcessor};
-use gpui::TestAppContext;
-use language_model::{LanguageModelRegistry, SelectedModel};
-use rand::{SeedableRng as _, rngs::StdRng};
-
-#[test]
-#[cfg_attr(not(feature = "unit-eval"), ignore)]
-fn eval_single_cursor_edit() {
- eval_utils::eval(20, 1.0, NoProcessor, move || {
- run_eval(
- &EvalInput {
- prompt: "Rename this variable to buffer_text".to_string(),
- buffer: indoc::indoc! {"
- struct EvalExampleStruct {
- text: Strหing,
- prompt: String,
- }
- "}
- .to_string(),
- },
- &|_, output| {
- let expected = indoc::indoc! {"
- struct EvalExampleStruct {
- buffer_text: String,
- prompt: String,
- }
- "};
- if output == expected {
- EvalOutput {
- outcome: eval_utils::OutcomeKind::Passed,
- data: "Passed!".to_string(),
- metadata: (),
- }
- } else {
- EvalOutput {
- outcome: eval_utils::OutcomeKind::Failed,
- data: format!("Failed to rename variable, output: {}", output),
- metadata: (),
- }
- }
- },
- )
- });
-}
-
-struct EvalInput {
- buffer: String,
- prompt: String,
-}
-
-fn run_eval(
- input: &EvalInput,
- judge: &dyn Fn(&EvalInput, &str) -> eval_utils::EvalOutput<()>,
-) -> eval_utils::EvalOutput<()> {
- let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng());
- let mut cx = TestAppContext::build(dispatcher, None);
- cx.skip_drawing();
-
- let buffer_text = run_inline_assistant_test(
- input.buffer.clone(),
- input.prompt.clone(),
- |cx| {
- // Reconfigure to use a real model instead of the fake one
- let model_name = std::env::var("ZED_AGENT_MODEL")
- .unwrap_or("anthropic/claude-sonnet-4-latest".into());
-
- let selected_model = SelectedModel::from_str(&model_name)
- .expect("Invalid model format. Use 'provider/model-id'");
-
- log::info!("Selected model: {selected_model:?}");
-
- cx.update(|_, cx| {
- LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- registry.select_inline_assistant_model(Some(&selected_model), cx);
- });
- });
- },
- |_cx| {
- log::info!("Waiting for actual response from the LLM...");
- },
- &mut cx,
- );
-
- judge(input, &buffer_text)
-}
@@ -117,14 +117,6 @@ impl InlineAssistant {
}
}
- #[cfg(any(test, feature = "test-support"))]
- pub fn set_completion_receiver(
- &mut self,
- sender: mpsc::UnboundedSender<anyhow::Result<InlineAssistId>>,
- ) {
- self._inline_assistant_completions = Some(sender);
- }
-
pub fn register_workspace(
&mut self,
workspace: &Entity<Workspace>,
@@ -1593,6 +1585,27 @@ impl InlineAssistant {
.map(InlineAssistTarget::Terminal)
}
}
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn set_completion_receiver(
+ &mut self,
+ sender: mpsc::UnboundedSender<anyhow::Result<InlineAssistId>>,
+ ) {
+ self._inline_assistant_completions = Some(sender);
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn get_codegen(
+ &mut self,
+ assist_id: InlineAssistId,
+ cx: &mut App,
+ ) -> Option<Entity<CodegenAlternative>> {
+ self.assists.get(&assist_id).map(|inline_assist| {
+ inline_assist
+ .codegen
+ .update(cx, |codegen, _cx| codegen.active_alternative().clone())
+ })
+ }
}
struct EditorInlineAssists {
@@ -2014,8 +2027,10 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
}
}
-#[cfg(any(test, feature = "test-support"))]
+#[cfg(any(test, feature = "unit-eval"))]
+#[cfg_attr(not(test), allow(dead_code))]
pub mod test {
+
use std::sync::Arc;
use agent::HistoryStore;
@@ -2026,7 +2041,6 @@ pub mod test {
use futures::channel::mpsc;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use language::Buffer;
- use language_model::LanguageModelRegistry;
use project::Project;
use prompt_store::PromptBuilder;
use smol::stream::StreamExt as _;
@@ -2035,13 +2049,43 @@ pub mod test {
use crate::InlineAssistant;
+ #[derive(Debug)]
+ pub enum InlineAssistantOutput {
+ Success {
+ completion: Option<String>,
+ description: Option<String>,
+ full_buffer_text: String,
+ },
+ Failure {
+ failure: String,
+ },
+ // These fields are used for logging
+ #[allow(unused)]
+ Malformed {
+ completion: Option<String>,
+ description: Option<String>,
+ failure: Option<String>,
+ },
+ }
+
+ impl InlineAssistantOutput {
+ pub fn buffer_text(&self) -> &str {
+ match self {
+ InlineAssistantOutput::Success {
+ full_buffer_text, ..
+ } => full_buffer_text,
+ _ => "",
+ }
+ }
+ }
+
pub fn run_inline_assistant_test<SetupF, TestF>(
base_buffer: String,
prompt: String,
setup: SetupF,
test: TestF,
cx: &mut TestAppContext,
- ) -> String
+ ) -> InlineAssistantOutput
where
SetupF: FnOnce(&mut gpui::VisualTestContext),
TestF: FnOnce(&mut gpui::VisualTestContext),
@@ -2133,39 +2177,198 @@ pub mod test {
test(cx);
- cx.executor()
- .block_test(async { completion_rx.next().await });
+ let assist_id = cx
+ .executor()
+ .block_test(async { completion_rx.next().await })
+ .unwrap()
+ .unwrap();
+
+ let (completion, description, failure) = cx.update(|_, cx| {
+ InlineAssistant::update_global(cx, |inline_assistant, cx| {
+ let codegen = inline_assistant.get_codegen(assist_id, cx).unwrap();
+
+ let completion = codegen.read(cx).current_completion();
+ let description = codegen.read(cx).current_description();
+ let failure = codegen.read(cx).current_failure();
- buffer.read_with(cx, |buffer, _| buffer.text())
+ (completion, description, failure)
+ })
+ });
+
+ if failure.is_some() && (completion.is_some() || description.is_some()) {
+ InlineAssistantOutput::Malformed {
+ completion,
+ description,
+ failure,
+ }
+ } else if let Some(failure) = failure {
+ InlineAssistantOutput::Failure { failure }
+ } else {
+ InlineAssistantOutput::Success {
+ completion,
+ description,
+ full_buffer_text: buffer.read_with(cx, |buffer, _| buffer.text()),
+ }
+ }
}
+}
- #[allow(unused)]
- pub fn test_inline_assistant(
- base_buffer: &'static str,
- llm_output: &'static str,
- cx: &mut TestAppContext,
- ) -> String {
- run_inline_assistant_test(
- base_buffer.to_string(),
- "Prompt doesn't matter because we're using a fake model".to_string(),
- |cx| {
- cx.update(|_, cx| LanguageModelRegistry::test(cx));
- },
- |cx| {
- let fake_model = cx.update(|_, cx| {
- LanguageModelRegistry::global(cx)
- .update(cx, |registry, _| registry.fake_model())
- });
- let fake = fake_model.as_fake();
+#[cfg(any(test, feature = "unit-eval"))]
+#[cfg_attr(not(test), allow(dead_code))]
+pub mod evals {
+ use std::str::FromStr;
+
+ use eval_utils::{EvalOutput, NoProcessor};
+ use gpui::TestAppContext;
+ use language_model::{LanguageModelRegistry, SelectedModel};
+ use rand::{SeedableRng as _, rngs::StdRng};
+
+ use crate::inline_assistant::test::{InlineAssistantOutput, run_inline_assistant_test};
+
+ #[test]
+ #[cfg_attr(not(feature = "unit-eval"), ignore)]
+ fn eval_single_cursor_edit() {
+ run_eval(
+ 20,
+ 1.0,
+ "Rename this variable to buffer_text".to_string(),
+ indoc::indoc! {"
+ struct EvalExampleStruct {
+ text: Strหing,
+ prompt: String,
+ }
+ "}
+ .to_string(),
+ exact_buffer_match(indoc::indoc! {"
+ struct EvalExampleStruct {
+ buffer_text: String,
+ prompt: String,
+ }
+ "}),
+ );
+ }
- // let fake = fake_model;
- fake.send_last_completion_stream_text_chunk(llm_output.to_string());
- fake.end_last_completion_stream();
+ #[test]
+ #[cfg_attr(not(feature = "unit-eval"), ignore)]
+ fn eval_cant_do() {
+ run_eval(
+ 20,
+ 1.0,
+ "Rename the struct to EvalExampleStructNope",
+ indoc::indoc! {"
+ struct EvalExampleStruct {
+ text: Strหing,
+ prompt: String,
+ }
+ "},
+ uncertain_output,
+ );
+ }
- // Run again to process the model's response
- cx.run_until_parked();
- },
- cx,
- )
+ #[test]
+ #[cfg_attr(not(feature = "unit-eval"), ignore)]
+ fn eval_unclear() {
+ run_eval(
+ 20,
+ 1.0,
+ "Make exactly the change I want you to make",
+ indoc::indoc! {"
+ struct EvalExampleStruct {
+ text: Strหing,
+ prompt: String,
+ }
+ "},
+ uncertain_output,
+ );
+ }
+
+ fn run_eval(
+ iterations: usize,
+ expected_pass_ratio: f32,
+ prompt: impl Into<String>,
+ buffer: impl Into<String>,
+ judge: impl Fn(InlineAssistantOutput) -> eval_utils::EvalOutput<()> + Send + Sync + 'static,
+ ) {
+ let buffer = buffer.into();
+ let prompt = prompt.into();
+
+ eval_utils::eval(iterations, expected_pass_ratio, NoProcessor, move || {
+ let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng());
+ let mut cx = TestAppContext::build(dispatcher, None);
+ cx.skip_drawing();
+
+ let output = run_inline_assistant_test(
+ buffer.clone(),
+ prompt.clone(),
+ |cx| {
+ // Reconfigure to use a real model instead of the fake one
+ let model_name = std::env::var("ZED_AGENT_MODEL")
+ .unwrap_or("anthropic/claude-sonnet-4-latest".into());
+
+ let selected_model = SelectedModel::from_str(&model_name)
+ .expect("Invalid model format. Use 'provider/model-id'");
+
+ log::info!("Selected model: {selected_model:?}");
+
+ cx.update(|_, cx| {
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry.select_inline_assistant_model(Some(&selected_model), cx);
+ });
+ });
+ },
+ |_cx| {
+ log::info!("Waiting for actual response from the LLM...");
+ },
+ &mut cx,
+ );
+
+ cx.quit();
+
+ judge(output)
+ });
+ }
+
+ fn uncertain_output(output: InlineAssistantOutput) -> EvalOutput<()> {
+ match &output {
+ o @ InlineAssistantOutput::Success {
+ completion,
+ description,
+ ..
+ } => {
+ if description.is_some() && completion.is_none() {
+ EvalOutput::passed(format!(
+ "Assistant produced no completion, but a description:\n{}",
+ description.as_ref().unwrap()
+ ))
+ } else {
+ EvalOutput::failed(format!("Assistant produced a completion:\n{:?}", o))
+ }
+ }
+ InlineAssistantOutput::Failure {
+ failure: error_message,
+ } => EvalOutput::passed(format!(
+ "Assistant produced a failure message: {}",
+ error_message
+ )),
+ o @ InlineAssistantOutput::Malformed { .. } => {
+ EvalOutput::failed(format!("Assistant produced a malformed response:\n{:?}", o))
+ }
+ }
+ }
+
+ fn exact_buffer_match(
+ correct_output: impl Into<String>,
+ ) -> impl Fn(InlineAssistantOutput) -> EvalOutput<()> {
+ let correct_output = correct_output.into();
+ move |output| {
+ if output.buffer_text() == correct_output {
+ EvalOutput::passed("Assistant output matches")
+ } else {
+ EvalOutput::failed(format!(
+ "Assistant output does not match expected output: {:?}",
+ output
+ ))
+ }
+ }
}
}
@@ -101,11 +101,11 @@ impl<T: 'static> Render for PromptEditor<T> {
let left_gutter_width = gutter.full_width() + (gutter.margin / 2.0);
let right_padding = editor_margins.right + RIGHT_PADDING;
- let explanation = codegen
- .active_alternative()
- .read(cx)
- .model_explanation
- .clone();
+ let active_alternative = codegen.active_alternative().read(cx);
+ let explanation = active_alternative
+ .description
+ .clone()
+ .or_else(|| active_alternative.failure.clone());
(left_gutter_width, right_padding, explanation)
}
@@ -139,7 +139,7 @@ impl<T: 'static> Render for PromptEditor<T> {
if let Some(explanation) = &explanation {
markdown.update(cx, |markdown, cx| {
- markdown.reset(explanation.clone(), cx);
+ markdown.reset(SharedString::from(explanation), cx);
});
}
@@ -40,6 +40,24 @@ pub struct EvalOutput<M> {
pub metadata: M,
}
+impl<M: Default> EvalOutput<M> {
+ pub fn passed(message: impl Into<String>) -> Self {
+ EvalOutput {
+ outcome: OutcomeKind::Passed,
+ data: message.into(),
+ metadata: M::default(),
+ }
+ }
+
+ pub fn failed(message: impl Into<String>) -> Self {
+ EvalOutput {
+ outcome: OutcomeKind::Failed,
+ data: message.into(),
+ metadata: M::default(),
+ }
+ }
+}
+
pub struct NoProcessor;
impl EvalOutputProcessor for NoProcessor {
type Metadata = ();
@@ -18,6 +18,6 @@ impl FeatureFlag for InlineAssistantUseToolFeatureFlag {
const NAME: &'static str = "inline-assistant-use-tool";
fn enabled_for_staff() -> bool {
- false
+ true
}
}
@@ -17,7 +17,7 @@ use settings::{Settings, SettingsStore};
use std::collections::HashMap;
use std::pin::Pin;
use std::str::FromStr;
-use std::sync::{Arc, LazyLock, OnceLock};
+use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
@@ -31,7 +31,6 @@ static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY";
static CODESTRAL_API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME);
-static CODESTRAL_API_KEY: OnceLock<Entity<ApiKeyState>> = OnceLock::new();
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
@@ -49,14 +48,18 @@ pub struct State {
codestral_api_key_state: Entity<ApiKeyState>,
}
+struct CodestralApiKey(Entity<ApiKeyState>);
+impl Global for CodestralApiKey {}
+
pub fn codestral_api_key(cx: &mut App) -> Entity<ApiKeyState> {
- return CODESTRAL_API_KEY
- .get_or_init(|| {
- cx.new(|_| {
- ApiKeyState::new(CODESTRAL_API_URL.into(), CODESTRAL_API_KEY_ENV_VAR.clone())
- })
- })
- .clone();
+ if cx.has_global::<CodestralApiKey>() {
+ cx.global::<CodestralApiKey>().0.clone()
+ } else {
+ let api_key_state = cx
+ .new(|_| ApiKeyState::new(CODESTRAL_API_URL.into(), CODESTRAL_API_KEY_ENV_VAR.clone()));
+ cx.set_global(CodestralApiKey(api_key_state.clone()));
+ api_key_state
+ }
}
impl State {