1use anyhow::{Result, anyhow};
  2use credentials_provider::CredentialsProvider;
  3use futures::{FutureExt, future};
  4use gpui::{AsyncApp, Context, SharedString, Task};
  5use language_model::AuthenticateError;
  6use std::{
  7    fmt::{Display, Formatter},
  8    sync::Arc,
  9};
 10use util::ResultExt as _;
 11use zed_env_vars::EnvVar;
 12
 13/// Manages a single API key for a language model provider. API keys either come from environment
 14/// variables or the system keychain.
 15///
 16/// Keys from the system keychain are associated with a provider URL, and this ensures that they are
 17/// only used with that URL.
 18pub struct ApiKeyState {
 19    url: SharedString,
 20    load_status: LoadStatus,
 21    load_task: Option<future::Shared<Task<()>>>,
 22}
 23
 24#[derive(Debug, Clone)]
 25pub enum LoadStatus {
 26    NotPresent,
 27    Error(String),
 28    Loaded(ApiKey),
 29}
 30
 31#[derive(Debug, Clone)]
 32pub struct ApiKey {
 33    source: ApiKeySource,
 34    key: Arc<str>,
 35}
 36
 37impl ApiKeyState {
 38    pub fn new(url: SharedString) -> Self {
 39        Self {
 40            url,
 41            load_status: LoadStatus::NotPresent,
 42            load_task: None,
 43        }
 44    }
 45
 46    pub fn has_key(&self) -> bool {
 47        matches!(self.load_status, LoadStatus::Loaded { .. })
 48    }
 49
 50    pub fn is_from_env_var(&self) -> bool {
 51        match &self.load_status {
 52            LoadStatus::Loaded(ApiKey {
 53                source: ApiKeySource::EnvVar { .. },
 54                ..
 55            }) => true,
 56            _ => false,
 57        }
 58    }
 59
 60    /// Get the stored API key, verifying that it is associated with the URL. Returns `None` if
 61    /// there is no key or for URL mismatches, and the mismatch case is logged.
 62    ///
 63    /// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been
 64    /// called with this URL.
 65    pub fn key(&self, url: &str) -> Option<Arc<str>> {
 66        let api_key = match &self.load_status {
 67            LoadStatus::Loaded(api_key) => api_key,
 68            _ => return None,
 69        };
 70        if url == self.url.as_str() {
 71            Some(api_key.key.clone())
 72        } else if let ApiKeySource::EnvVar(var_name) = &api_key.source {
 73            log::warn!(
 74                "{} is now being used with URL {}, when initially it was used with URL {}",
 75                var_name,
 76                url,
 77                self.url
 78            );
 79            Some(api_key.key.clone())
 80        } else {
 81            // bug case because load_if_needed should be called whenever the url may have changed
 82            log::error!(
 83                "bug: Attempted to use API key associated with URL {} instead with URL {}",
 84                self.url,
 85                url
 86            );
 87            None
 88        }
 89    }
 90
 91    /// Set or delete the API key in the system keychain.
 92    pub fn store<Ent: 'static>(
 93        &mut self,
 94        url: SharedString,
 95        key: Option<String>,
 96        get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
 97        cx: &Context<Ent>,
 98    ) -> Task<Result<()>> {
 99        if self.is_from_env_var() {
100            return Task::ready(Err(anyhow!(
101                "bug: attempted to store API key in system keychain when API key is from env var",
102            )));
103        }
104        let credentials_provider = <dyn CredentialsProvider>::global(cx);
105        cx.spawn(async move |ent, cx| {
106            if let Some(key) = &key {
107                credentials_provider
108                    .write_credentials(&url, "Bearer", key.as_bytes(), cx)
109                    .await
110                    .log_err();
111            } else {
112                credentials_provider
113                    .delete_credentials(&url, cx)
114                    .await
115                    .log_err();
116            }
117            ent.update(cx, |ent, cx| {
118                let this = get_this(ent);
119                this.url = url;
120                this.load_status = match &key {
121                    Some(key) => LoadStatus::Loaded(ApiKey {
122                        source: ApiKeySource::SystemKeychain,
123                        key: key.as_str().into(),
124                    }),
125                    None => LoadStatus::NotPresent,
126                };
127                cx.notify();
128            })
129        })
130    }
131
132    /// Reloads the API key if the current API key is associated with a different URL.
133    ///
134    /// Note that it is not efficient to use this or `load_if_needed` with multiple URLs
135    /// interchangeably - URL change should correspond to some user initiated change.
136    pub fn handle_url_change<Ent: 'static>(
137        &mut self,
138        url: SharedString,
139        env_var: &EnvVar,
140        get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
141        cx: &mut Context<Ent>,
142    ) {
143        if url != self.url {
144            if !self.is_from_env_var() {
145                // loading will continue even though this result task is dropped
146                let _task = self.load_if_needed(url, env_var, get_this, cx);
147            }
148        }
149    }
150
151    /// If needed, loads the API key associated with the given URL from the system keychain. When a
152    /// non-empty environment variable is provided, it will be used instead. If called when an API
153    /// key was already loaded for a different URL, that key will be cleared before loading.
154    ///
155    /// Dropping the returned Task does not cancel key loading.
156    pub fn load_if_needed<Ent: 'static>(
157        &mut self,
158        url: SharedString,
159        env_var: &EnvVar,
160        get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
161        cx: &mut Context<Ent>,
162    ) -> Task<Result<(), AuthenticateError>> {
163        if let LoadStatus::Loaded { .. } = &self.load_status
164            && self.url == url
165        {
166            return Task::ready(Ok(()));
167        }
168
169        if let Some(key) = &env_var.value
170            && !key.is_empty()
171        {
172            let api_key = ApiKey::from_env(env_var.name.clone(), key);
173            self.url = url;
174            self.load_status = LoadStatus::Loaded(api_key);
175            self.load_task = None;
176            cx.notify();
177            return Task::ready(Ok(()));
178        }
179
180        let task = if let Some(load_task) = &self.load_task {
181            load_task.clone()
182        } else {
183            let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
184            self.url = url;
185            self.load_status = LoadStatus::NotPresent;
186            self.load_task = Some(load_task.clone());
187            cx.notify();
188            load_task
189        };
190
191        cx.spawn(async move |ent, cx| {
192            task.await;
193            ent.update(cx, |ent, _cx| {
194                get_this(ent).load_status.clone().into_authenticate_result()
195            })
196            .ok();
197            Ok(())
198        })
199    }
200
201    fn load<Ent: 'static>(
202        url: SharedString,
203        get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
204        cx: &Context<Ent>,
205    ) -> Task<()> {
206        let credentials_provider = <dyn CredentialsProvider>::global(cx);
207        cx.spawn({
208            async move |ent, cx| {
209                let load_status =
210                    ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
211                        .await;
212                ent.update(cx, |ent, cx| {
213                    let this = get_this(ent);
214                    this.url = url;
215                    this.load_status = load_status;
216                    this.load_task = None;
217                    cx.notify();
218                })
219                .ok();
220            }
221        })
222    }
223}
224
225impl ApiKey {
226    pub fn key(&self) -> &str {
227        &self.key
228    }
229
230    pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
231        Self {
232            source: ApiKeySource::EnvVar(env_var_name),
233            key: key.into(),
234        }
235    }
236
237    pub async fn load_from_system_keychain(
238        url: &str,
239        credentials_provider: &dyn CredentialsProvider,
240        cx: &AsyncApp,
241    ) -> Result<Self, AuthenticateError> {
242        Self::load_from_system_keychain_impl(url, credentials_provider, cx)
243            .await
244            .into_authenticate_result()
245    }
246
247    async fn load_from_system_keychain_impl(
248        url: &str,
249        credentials_provider: &dyn CredentialsProvider,
250        cx: &AsyncApp,
251    ) -> LoadStatus {
252        if url.is_empty() {
253            return LoadStatus::NotPresent;
254        }
255        let read_result = credentials_provider.read_credentials(&url, cx).await;
256        let api_key = match read_result {
257            Ok(Some((_, api_key))) => api_key,
258            Ok(None) => return LoadStatus::NotPresent,
259            Err(err) => return LoadStatus::Error(err.to_string()),
260        };
261        let key = match str::from_utf8(&api_key) {
262            Ok(key) => key,
263            Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")),
264        };
265        LoadStatus::Loaded(Self {
266            source: ApiKeySource::SystemKeychain,
267            key: key.into(),
268        })
269    }
270}
271
272impl LoadStatus {
273    fn into_authenticate_result(self) -> Result<ApiKey, AuthenticateError> {
274        match self {
275            LoadStatus::Loaded(api_key) => Ok(api_key),
276            LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound),
277            LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))),
278        }
279    }
280}
281
282#[derive(Debug, Clone)]
283enum ApiKeySource {
284    EnvVar(SharedString),
285    SystemKeychain,
286}
287
288impl Display for ApiKeySource {
289    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
290        match self {
291            ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var),
292            ApiKeySource::SystemKeychain => write!(f, "system keychain"),
293        }
294    }
295}