diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index cfd43033515e2c3527c8d0dfbf1267fb96793819..fca1cf977cb5e3b32dc6f2335fb0d9188979bc9f 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -56,13 +56,13 @@ pub struct OpenAiLanguageModelProvider { pub struct State { api_key: Option, api_key_from_env: bool, + last_api_url: String, _subscription: Subscription, } const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY"; impl State { - // fn is_authenticated(&self) -> bool { self.api_key.is_some() } @@ -104,11 +104,7 @@ impl State { }) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - + fn get_api_key(&self, cx: &mut Context) -> Task> { let credentials_provider = ::global(cx); let api_url = AllLanguageModelSettings::get_global(cx) .openai @@ -136,14 +132,52 @@ impl State { Ok(()) }) } + + fn authenticate(&self, cx: &mut Context) -> Task> { + if self.is_authenticated() { + return Task::ready(Ok(())); + } + + self.get_api_key(cx) + } } impl OpenAiLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { + let initial_api_url = AllLanguageModelSettings::get_global(cx) + .openai + .api_url + .clone(); + let state = cx.new(|cx| State { api_key: None, api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + last_api_url: initial_api_url.clone(), + _subscription: cx.observe_global::(|this: &mut State, cx| { + let current_api_url = AllLanguageModelSettings::get_global(cx) + .openai + .api_url + .clone(); + + if this.last_api_url != current_api_url { + this.last_api_url = current_api_url; + if !this.api_key_from_env { + this.api_key = None; + let spawn_task = cx.spawn(async move |handle, cx| { + if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) { + if let Err(_) = task.await { + handle + .update(cx, |this, _| { + this.api_key = None; + this.api_key_from_env = false; + }) + .ok(); + } + } + }); + spawn_task.detach(); + } + } cx.notify(); }), }); diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 789eb00a5746c729103f77a1e92d0e58fc4c1ab0..4ebb11a07b66ec7054ca65437ec887a415fa3f5c 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -113,11 +113,7 @@ impl State { }) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - + fn get_api_key(&self, cx: &mut Context) -> Task> { let credentials_provider = ::global(cx); let env_var_name = self.env_var_name.clone(); let api_url = self.settings.api_url.clone(); @@ -143,6 +139,14 @@ impl State { Ok(()) }) } + + fn authenticate(&self, cx: &mut Context) -> Task> { + if self.is_authenticated() { + return Task::ready(Ok(())); + } + + self.get_api_key(cx) + } } impl OpenAiCompatibleLanguageModelProvider { @@ -160,11 +164,27 @@ impl OpenAiCompatibleLanguageModelProvider { api_key: None, api_key_from_env: false, _subscription: cx.observe_global::(|this: &mut State, cx| { - let Some(settings) = resolve_settings(&this.id, cx) else { + let Some(settings) = resolve_settings(&this.id, cx).cloned() else { return; }; - if &this.settings != settings { - this.settings = settings.clone(); + if &this.settings != &settings { + if settings.api_url != this.settings.api_url && !this.api_key_from_env { + let spawn_task = cx.spawn(async move |handle, cx| { + if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) { + if let Err(_) = task.await { + handle + .update(cx, |this, _| { + this.api_key = None; + this.api_key_from_env = false; + }) + .ok(); + } + } + }); + spawn_task.detach(); + } + + this.settings = settings; cx.notify(); } }),