api_key.rs

  1use anyhow::{Result, anyhow};
  2use credentials_provider::CredentialsProvider;
  3use futures::{FutureExt, future};
  4use gpui::{AsyncApp, Context, SharedString, Task};
  5use language_model::AuthenticateError;
  6use std::{
  7    fmt::{Display, Formatter},
  8    sync::Arc,
  9};
 10use util::ResultExt as _;
 11use zed_env_vars::EnvVar;
 12
 13/// Manages a single API key for a language model provider. API keys either come from environment
 14/// variables or the system keychain.
 15///
 16/// Keys from the system keychain are associated with a provider URL, and this ensures that they are
 17/// only used with that URL.
 18pub struct ApiKeyState {
 19    url: SharedString,
 20    load_status: LoadStatus,
 21    load_task: Option<future::Shared<Task<()>>>,
 22}
 23
 24#[derive(Debug, Clone)]
 25pub enum LoadStatus {
 26    NotPresent,
 27    Error(String),
 28    Loaded(ApiKey),
 29}
 30
 31#[derive(Debug, Clone)]
 32pub struct ApiKey {
 33    source: ApiKeySource,
 34    key: Arc<str>,
 35}
 36
 37impl ApiKeyState {
 38    pub fn new(url: SharedString) -> Self {
 39        Self {
 40            url,
 41            load_status: LoadStatus::NotPresent,
 42            load_task: None,
 43        }
 44    }
 45
 46    pub fn has_key(&self) -> bool {
 47        matches!(self.load_status, LoadStatus::Loaded { .. })
 48    }
 49
 50    pub fn is_from_env_var(&self) -> bool {
 51        match &self.load_status {
 52            LoadStatus::Loaded(ApiKey {
 53                source: ApiKeySource::EnvVar { .. },
 54                ..
 55            }) => true,
 56            _ => false,
 57        }
 58    }
 59
 60    /// Get the stored API key, verifying that it is associated with the URL. Returns `None` if
 61    /// there is no key or for URL mismatches, and the mismatch case is logged.
 62    ///
 63    /// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been
 64    /// called with this URL.
 65    pub fn key(&self, url: &str) -> Option<Arc<str>> {
 66        let api_key = match &self.load_status {
 67            LoadStatus::Loaded(api_key) => api_key,
 68            _ => return None,
 69        };
 70        if url == self.url.as_str() {
 71            Some(api_key.key.clone())
 72        } else if let ApiKeySource::EnvVar(var_name) = &api_key.source {
 73            log::warn!(
 74                "{} is now being used with URL {}, when initially it was used with URL {}",
 75                var_name,
 76                url,
 77                self.url
 78            );
 79            Some(api_key.key.clone())
 80        } else {
 81            // bug case because load_if_needed should be called whenever the url may have changed
 82            log::error!(
 83                "bug: Attempted to use API key associated with URL {} instead with URL {}",
 84                self.url,
 85                url
 86            );
 87            None
 88        }
 89    }
 90
 91    /// Set or delete the API key in the system keychain.
 92    pub fn store<Ent: 'static>(
 93        &mut self,
 94        url: SharedString,
 95        key: Option<String>,
 96        get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
 97        cx: &Context<Ent>,
 98    ) -> Task<Result<()>> {
 99        if self.is_from_env_var() {
100            return Task::ready(Err(anyhow!(
101                "bug: attempted to store API key in system keychain when API key is from env var",
102            )));
103        }
104        let credentials_provider = <dyn CredentialsProvider>::global(cx);
105        cx.spawn(async move |ent, cx| {
106            if let Some(key) = &key {
107                credentials_provider
108                    .write_credentials(&url, "Bearer", key.as_bytes(), cx)
109                    .await
110                    .log_err();
111            } else {
112                credentials_provider
113                    .delete_credentials(&url, cx)
114                    .await
115                    .log_err();
116            }
117            ent.update(cx, |ent, cx| {
118                let this = get_this(ent);
119                this.url = url;
120                this.load_status = match &key {
121                    Some(key) => LoadStatus::Loaded(ApiKey {
122                        source: ApiKeySource::SystemKeychain,
123                        key: key.as_str().into(),
124                    }),
125                    None => LoadStatus::NotPresent,
126                };
127                cx.notify();
128            })
129        })
130    }
131
132    /// Reloads the API key if the current API key is associated with a different URL.
133    ///
134    /// Note that it is not efficient to use this or `load_if_needed` with multiple URLs
135    /// interchangeably - URL change should correspond to some user initiated change.
136    pub fn handle_url_change<Ent: 'static>(
137        &mut self,
138        url: SharedString,
139        env_var: &EnvVar,
140        get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
141        cx: &mut Context<Ent>,
142    ) {
143        if url != self.url {
144            if !self.is_from_env_var() {
145                // loading will continue even though this result task is dropped
146                let _task = self.load_if_needed(url, env_var, get_this, cx);
147            }
148        }
149    }
150
151    /// If needed, loads the API key associated with the given URL from the system keychain. When a
152    /// non-empty environment variable is provided, it will be used instead. If called when an API
153    /// key was already loaded for a different URL, that key will be cleared before loading.
154    ///
155    /// Dropping the returned Task does not cancel key loading.
156    pub fn load_if_needed<Ent: 'static>(
157        &mut self,
158        url: SharedString,
159        env_var: &EnvVar,
160        get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
161        cx: &mut Context<Ent>,
162    ) -> Task<Result<(), AuthenticateError>> {
163        if let LoadStatus::Loaded { .. } = &self.load_status
164            && self.url == url
165        {
166            return Task::ready(Ok(()));
167        }
168
169        if let Some(key) = &env_var.value
170            && !key.is_empty()
171        {
172            let api_key = ApiKey::from_env(env_var.name.clone(), key);
173            self.url = url;
174            self.load_status = LoadStatus::Loaded(api_key);
175            self.load_task = None;
176            cx.notify();
177            return Task::ready(Ok(()));
178        }
179
180        let task = if let Some(load_task) = &self.load_task {
181            load_task.clone()
182        } else {
183            let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
184            self.url = url;
185            self.load_status = LoadStatus::NotPresent;
186            self.load_task = Some(load_task.clone());
187            cx.notify();
188            load_task
189        };
190
191        cx.spawn(async move |ent, cx| {
192            task.await;
193            ent.update(cx, |ent, _cx| {
194                get_this(ent).load_status.clone().into_authenticate_result()
195            })
196            .ok();
197            Ok(())
198        })
199    }
200
201    fn load<Ent: 'static>(
202        url: SharedString,
203        get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
204        cx: &Context<Ent>,
205    ) -> Task<()> {
206        let credentials_provider = <dyn CredentialsProvider>::global(cx);
207        cx.spawn({
208            async move |ent, cx| {
209                let load_status =
210                    ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
211                        .await;
212                ent.update(cx, |ent, cx| {
213                    let this = get_this(ent);
214                    this.url = url;
215                    this.load_status = load_status;
216                    this.load_task = None;
217                    cx.notify();
218                })
219                .ok();
220            }
221        })
222    }
223}
224
225impl ApiKey {
226    pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
227        Self {
228            source: ApiKeySource::EnvVar(env_var_name),
229            key: key.into(),
230        }
231    }
232
233    async fn load_from_system_keychain_impl(
234        url: &str,
235        credentials_provider: &dyn CredentialsProvider,
236        cx: &AsyncApp,
237    ) -> LoadStatus {
238        if url.is_empty() {
239            return LoadStatus::NotPresent;
240        }
241        let read_result = credentials_provider.read_credentials(&url, cx).await;
242        let api_key = match read_result {
243            Ok(Some((_, api_key))) => api_key,
244            Ok(None) => return LoadStatus::NotPresent,
245            Err(err) => return LoadStatus::Error(err.to_string()),
246        };
247        let key = match str::from_utf8(&api_key) {
248            Ok(key) => key,
249            Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")),
250        };
251        LoadStatus::Loaded(Self {
252            source: ApiKeySource::SystemKeychain,
253            key: key.into(),
254        })
255    }
256}
257
258impl LoadStatus {
259    fn into_authenticate_result(self) -> Result<ApiKey, AuthenticateError> {
260        match self {
261            LoadStatus::Loaded(api_key) => Ok(api_key),
262            LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound),
263            LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))),
264        }
265    }
266}
267
268#[derive(Debug, Clone)]
269enum ApiKeySource {
270    EnvVar(SharedString),
271    SystemKeychain,
272}
273
274impl Display for ApiKeySource {
275    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
276        match self {
277            ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var),
278            ApiKeySource::SystemKeychain => write!(f, "system keychain"),
279        }
280    }
281}