api_key.rs

  1use anyhow::{Result, anyhow};
  2use credentials_provider::CredentialsProvider;
  3use env_var::EnvVar;
  4use futures::{FutureExt, future};
  5use gpui::{AsyncApp, Context, SharedString, Task};
  6use std::{
  7    fmt::{Display, Formatter},
  8    sync::Arc,
  9};
 10use util::ResultExt as _;
 11
 12use crate::AuthenticateError;
 13
 14/// Manages a single API key for a language model provider. API keys either come from environment
 15/// variables or the system keychain.
 16///
 17/// Keys from the system keychain are associated with a provider URL, and this ensures that they are
 18/// only used with that URL.
 19pub struct ApiKeyState {
 20    pub url: SharedString,
 21    env_var: EnvVar,
 22    load_status: LoadStatus,
 23    load_task: Option<future::Shared<Task<()>>>,
 24}
 25
 26#[derive(Debug, Clone)]
 27pub enum LoadStatus {
 28    NotPresent,
 29    Error(String),
 30    Loaded(ApiKey),
 31}
 32
 33#[derive(Debug, Clone)]
 34pub struct ApiKey {
 35    source: ApiKeySource,
 36    key: Arc<str>,
 37}
 38
 39impl ApiKeyState {
 40    pub fn new(url: SharedString, env_var: EnvVar) -> Self {
 41        Self {
 42            url,
 43            env_var,
 44            load_status: LoadStatus::NotPresent,
 45            load_task: None,
 46        }
 47    }
 48
 49    pub fn has_key(&self) -> bool {
 50        matches!(self.load_status, LoadStatus::Loaded { .. })
 51    }
 52
 53    pub fn env_var_name(&self) -> &SharedString {
 54        &self.env_var.name
 55    }
 56
 57    pub fn is_from_env_var(&self) -> bool {
 58        match &self.load_status {
 59            LoadStatus::Loaded(ApiKey {
 60                source: ApiKeySource::EnvVar { .. },
 61                ..
 62            }) => true,
 63            _ => false,
 64        }
 65    }
 66
 67    /// Get the stored API key, verifying that it is associated with the URL. Returns `None` if
 68    /// there is no key or for URL mismatches, and the mismatch case is logged.
 69    ///
 70    /// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been
 71    /// called with this URL.
 72    pub fn key(&self, url: &str) -> Option<Arc<str>> {
 73        let api_key = match &self.load_status {
 74            LoadStatus::Loaded(api_key) => api_key,
 75            _ => return None,
 76        };
 77        if url == self.url.as_str() {
 78            Some(api_key.key.clone())
 79        } else if let ApiKeySource::EnvVar(var_name) = &api_key.source {
 80            log::warn!(
 81                "{} is now being used with URL {}, when initially it was used with URL {}",
 82                var_name,
 83                url,
 84                self.url
 85            );
 86            Some(api_key.key.clone())
 87        } else {
 88            // bug case because load_if_needed should be called whenever the url may have changed
 89            log::error!(
 90                "bug: Attempted to use API key associated with URL {} instead with URL {}",
 91                self.url,
 92                url
 93            );
 94            None
 95        }
 96    }
 97
 98    /// Set or delete the API key in the system keychain.
 99    pub fn store<Ent: 'static>(
100        &mut self,
101        url: SharedString,
102        key: Option<String>,
103        get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
104        provider: Arc<dyn CredentialsProvider>,
105        cx: &Context<Ent>,
106    ) -> Task<Result<()>> {
107        if self.is_from_env_var() {
108            return Task::ready(Err(anyhow!(
109                "bug: attempted to store API key in system keychain when API key is from env var",
110            )));
111        }
112        cx.spawn(async move |ent, cx| {
113            if let Some(key) = &key {
114                provider
115                    .write_credentials(&url, "Bearer", key.as_bytes(), cx)
116                    .await
117                    .log_err();
118            } else {
119                provider.delete_credentials(&url, cx).await.log_err();
120            }
121            ent.update(cx, |ent, cx| {
122                let this = get_this(ent);
123                this.url = url;
124                this.load_status = match &key {
125                    Some(key) => LoadStatus::Loaded(ApiKey {
126                        source: ApiKeySource::SystemKeychain,
127                        key: key.as_str().into(),
128                    }),
129                    None => LoadStatus::NotPresent,
130                };
131                cx.notify();
132            })
133        })
134    }
135
136    /// Reloads the API key if the current API key is associated with a different URL.
137    ///
138    /// Note that it is not efficient to use this or `load_if_needed` with multiple URLs
139    /// interchangeably - URL change should correspond to some user initiated change.
140    pub fn handle_url_change<Ent: 'static>(
141        &mut self,
142        url: SharedString,
143        get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
144        provider: Arc<dyn CredentialsProvider>,
145        cx: &mut Context<Ent>,
146    ) {
147        if url != self.url {
148            if !self.is_from_env_var() {
149                // loading will continue even though this result task is dropped
150                let _task = self.load_if_needed(url, get_this, provider, cx);
151            }
152        }
153    }
154
155    /// If needed, loads the API key associated with the given URL from the system keychain. When a
156    /// non-empty environment variable is provided, it will be used instead. If called when an API
157    /// key was already loaded for a different URL, that key will be cleared before loading.
158    ///
159    /// Dropping the returned Task does not cancel key loading.
160    pub fn load_if_needed<Ent: 'static>(
161        &mut self,
162        url: SharedString,
163        get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
164        provider: Arc<dyn CredentialsProvider>,
165        cx: &mut Context<Ent>,
166    ) -> Task<Result<(), AuthenticateError>> {
167        if let LoadStatus::Loaded { .. } = &self.load_status
168            && self.url == url
169        {
170            return Task::ready(Ok(()));
171        }
172
173        if let Some(key) = &self.env_var.value
174            && !key.is_empty()
175        {
176            let api_key = ApiKey::from_env(self.env_var.name.clone(), key);
177            self.url = url;
178            self.load_status = LoadStatus::Loaded(api_key);
179            self.load_task = None;
180            cx.notify();
181            return Task::ready(Ok(()));
182        }
183
184        let task = if let Some(load_task) = &self.load_task {
185            load_task.clone()
186        } else {
187            let load_task = Self::load(url.clone(), get_this.clone(), provider, cx).shared();
188            self.url = url;
189            self.load_status = LoadStatus::NotPresent;
190            self.load_task = Some(load_task.clone());
191            cx.notify();
192            load_task
193        };
194
195        cx.spawn(async move |ent, cx| {
196            task.await;
197            ent.update(cx, |ent, _cx| {
198                get_this(ent).load_status.clone().into_authenticate_result()
199            })
200            .ok();
201            Ok(())
202        })
203    }
204
205    fn load<Ent: 'static>(
206        url: SharedString,
207        get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
208        provider: Arc<dyn CredentialsProvider>,
209        cx: &Context<Ent>,
210    ) -> Task<()> {
211        cx.spawn({
212            async move |ent, cx| {
213                let load_status =
214                    ApiKey::load_from_system_keychain_impl(&url, provider.as_ref(), cx).await;
215                ent.update(cx, |ent, cx| {
216                    let this = get_this(ent);
217                    this.url = url;
218                    this.load_status = load_status;
219                    this.load_task = None;
220                    cx.notify();
221                })
222                .ok();
223            }
224        })
225    }
226}
227
228impl ApiKey {
229    pub fn key(&self) -> &str {
230        &self.key
231    }
232
233    pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
234        Self {
235            source: ApiKeySource::EnvVar(env_var_name),
236            key: key.into(),
237        }
238    }
239
240    pub async fn load_from_system_keychain(
241        url: &str,
242        credentials_provider: &dyn CredentialsProvider,
243        cx: &AsyncApp,
244    ) -> Result<Self, AuthenticateError> {
245        Self::load_from_system_keychain_impl(url, credentials_provider, cx)
246            .await
247            .into_authenticate_result()
248    }
249
250    async fn load_from_system_keychain_impl(
251        url: &str,
252        credentials_provider: &dyn CredentialsProvider,
253        cx: &AsyncApp,
254    ) -> LoadStatus {
255        if url.is_empty() {
256            return LoadStatus::NotPresent;
257        }
258        let read_result = credentials_provider.read_credentials(&url, cx).await;
259        let api_key = match read_result {
260            Ok(Some((_, api_key))) => api_key,
261            Ok(None) => return LoadStatus::NotPresent,
262            Err(err) => return LoadStatus::Error(err.to_string()),
263        };
264        let key = match str::from_utf8(&api_key) {
265            Ok(key) => key,
266            Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")),
267        };
268        LoadStatus::Loaded(Self {
269            source: ApiKeySource::SystemKeychain,
270            key: key.into(),
271        })
272    }
273}
274
275impl LoadStatus {
276    fn into_authenticate_result(self) -> Result<ApiKey, AuthenticateError> {
277        match self {
278            LoadStatus::Loaded(api_key) => Ok(api_key),
279            LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound),
280            LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))),
281        }
282    }
283}
284
285#[derive(Debug, Clone)]
286enum ApiKeySource {
287    EnvVar(SharedString),
288    SystemKeychain,
289}
290
291impl Display for ApiKeySource {
292    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
293        match self {
294            ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var),
295            ApiKeySource::SystemKeychain => write!(f, "system keychain"),
296        }
297    }
298}