api_key.rs

  1use anyhow::{Result, anyhow};
  2use credentials_provider::CredentialsProvider;
  3use futures::{FutureExt, future};
  4use gpui::{AsyncApp, Context, SharedString, Task};
  5use std::{
  6    fmt::{Display, Formatter},
  7    sync::Arc,
  8};
  9use util::ResultExt as _;
 10use zed_env_vars::EnvVar;
 11
 12use crate::AuthenticateError;
 13
 14/// Manages a single API key for a language model provider. API keys either come from environment
 15/// variables or the system keychain.
 16///
 17/// Keys from the system keychain are associated with a provider URL, and this ensures that they are
 18/// only used with that URL.
 19pub struct ApiKeyState {
 20    pub url: SharedString,
 21    env_var: EnvVar,
 22    load_status: LoadStatus,
 23    load_task: Option<future::Shared<Task<()>>>,
 24}
 25
 26#[derive(Debug, Clone)]
 27pub enum LoadStatus {
 28    NotPresent,
 29    Error(String),
 30    Loaded(ApiKey),
 31}
 32
 33#[derive(Debug, Clone)]
 34pub struct ApiKey {
 35    source: ApiKeySource,
 36    key: Arc<str>,
 37}
 38
 39impl ApiKeyState {
 40    pub fn new(url: SharedString, env_var: EnvVar) -> Self {
 41        Self {
 42            url,
 43            env_var,
 44            load_status: LoadStatus::NotPresent,
 45            load_task: None,
 46        }
 47    }
 48
 49    pub fn has_key(&self) -> bool {
 50        matches!(self.load_status, LoadStatus::Loaded { .. })
 51    }
 52
 53    pub fn env_var_name(&self) -> &SharedString {
 54        &self.env_var.name
 55    }
 56
 57    pub fn is_from_env_var(&self) -> bool {
 58        match &self.load_status {
 59            LoadStatus::Loaded(ApiKey {
 60                source: ApiKeySource::EnvVar { .. },
 61                ..
 62            }) => true,
 63            _ => false,
 64        }
 65    }
 66
 67    /// Get the stored API key, verifying that it is associated with the URL. Returns `None` if
 68    /// there is no key or for URL mismatches, and the mismatch case is logged.
 69    ///
 70    /// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been
 71    /// called with this URL.
 72    pub fn key(&self, url: &str) -> Option<Arc<str>> {
 73        let api_key = match &self.load_status {
 74            LoadStatus::Loaded(api_key) => api_key,
 75            _ => return None,
 76        };
 77        if url == self.url.as_str() {
 78            Some(api_key.key.clone())
 79        } else if let ApiKeySource::EnvVar(var_name) = &api_key.source {
 80            log::warn!(
 81                "{} is now being used with URL {}, when initially it was used with URL {}",
 82                var_name,
 83                url,
 84                self.url
 85            );
 86            Some(api_key.key.clone())
 87        } else {
 88            // bug case because load_if_needed should be called whenever the url may have changed
 89            log::error!(
 90                "bug: Attempted to use API key associated with URL {} instead with URL {}",
 91                self.url,
 92                url
 93            );
 94            None
 95        }
 96    }
 97
 98    /// Set or delete the API key in the system keychain.
 99    pub fn store<Ent: 'static>(
100        &mut self,
101        url: SharedString,
102        key: Option<String>,
103        get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
104        cx: &Context<Ent>,
105    ) -> Task<Result<()>> {
106        if self.is_from_env_var() {
107            return Task::ready(Err(anyhow!(
108                "bug: attempted to store API key in system keychain when API key is from env var",
109            )));
110        }
111        let credentials_provider = <dyn CredentialsProvider>::global(cx);
112        cx.spawn(async move |ent, cx| {
113            if let Some(key) = &key {
114                credentials_provider
115                    .write_credentials(&url, "Bearer", key.as_bytes(), cx)
116                    .await
117                    .log_err();
118            } else {
119                credentials_provider
120                    .delete_credentials(&url, cx)
121                    .await
122                    .log_err();
123            }
124            ent.update(cx, |ent, cx| {
125                let this = get_this(ent);
126                this.url = url;
127                this.load_status = match &key {
128                    Some(key) => LoadStatus::Loaded(ApiKey {
129                        source: ApiKeySource::SystemKeychain,
130                        key: key.as_str().into(),
131                    }),
132                    None => LoadStatus::NotPresent,
133                };
134                cx.notify();
135            })
136        })
137    }
138
139    /// Reloads the API key if the current API key is associated with a different URL.
140    ///
141    /// Note that it is not efficient to use this or `load_if_needed` with multiple URLs
142    /// interchangeably - URL change should correspond to some user initiated change.
143    pub fn handle_url_change<Ent: 'static>(
144        &mut self,
145        url: SharedString,
146        get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
147        cx: &mut Context<Ent>,
148    ) {
149        if url != self.url {
150            if !self.is_from_env_var() {
151                // loading will continue even though this result task is dropped
152                let _task = self.load_if_needed(url, get_this, cx);
153            }
154        }
155    }
156
157    /// If needed, loads the API key associated with the given URL from the system keychain. When a
158    /// non-empty environment variable is provided, it will be used instead. If called when an API
159    /// key was already loaded for a different URL, that key will be cleared before loading.
160    ///
161    /// Dropping the returned Task does not cancel key loading.
162    pub fn load_if_needed<Ent: 'static>(
163        &mut self,
164        url: SharedString,
165        get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
166        cx: &mut Context<Ent>,
167    ) -> Task<Result<(), AuthenticateError>> {
168        if let LoadStatus::Loaded { .. } = &self.load_status
169            && self.url == url
170        {
171            return Task::ready(Ok(()));
172        }
173
174        if let Some(key) = &self.env_var.value
175            && !key.is_empty()
176        {
177            let api_key = ApiKey::from_env(self.env_var.name.clone(), key);
178            self.url = url;
179            self.load_status = LoadStatus::Loaded(api_key);
180            self.load_task = None;
181            cx.notify();
182            return Task::ready(Ok(()));
183        }
184
185        let task = if let Some(load_task) = &self.load_task {
186            load_task.clone()
187        } else {
188            let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
189            self.url = url;
190            self.load_status = LoadStatus::NotPresent;
191            self.load_task = Some(load_task.clone());
192            cx.notify();
193            load_task
194        };
195
196        cx.spawn(async move |ent, cx| {
197            task.await;
198            ent.update(cx, |ent, _cx| {
199                get_this(ent).load_status.clone().into_authenticate_result()
200            })
201            .ok();
202            Ok(())
203        })
204    }
205
206    fn load<Ent: 'static>(
207        url: SharedString,
208        get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
209        cx: &Context<Ent>,
210    ) -> Task<()> {
211        let credentials_provider = <dyn CredentialsProvider>::global(cx);
212        cx.spawn({
213            async move |ent, cx| {
214                let load_status =
215                    ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
216                        .await;
217                ent.update(cx, |ent, cx| {
218                    let this = get_this(ent);
219                    this.url = url;
220                    this.load_status = load_status;
221                    this.load_task = None;
222                    cx.notify();
223                })
224                .ok();
225            }
226        })
227    }
228}
229
230impl ApiKey {
231    pub fn key(&self) -> &str {
232        &self.key
233    }
234
235    pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
236        Self {
237            source: ApiKeySource::EnvVar(env_var_name),
238            key: key.into(),
239        }
240    }
241
242    pub async fn load_from_system_keychain(
243        url: &str,
244        credentials_provider: &dyn CredentialsProvider,
245        cx: &AsyncApp,
246    ) -> Result<Self, AuthenticateError> {
247        Self::load_from_system_keychain_impl(url, credentials_provider, cx)
248            .await
249            .into_authenticate_result()
250    }
251
252    async fn load_from_system_keychain_impl(
253        url: &str,
254        credentials_provider: &dyn CredentialsProvider,
255        cx: &AsyncApp,
256    ) -> LoadStatus {
257        if url.is_empty() {
258            return LoadStatus::NotPresent;
259        }
260        let read_result = credentials_provider.read_credentials(&url, cx).await;
261        let api_key = match read_result {
262            Ok(Some((_, api_key))) => api_key,
263            Ok(None) => return LoadStatus::NotPresent,
264            Err(err) => return LoadStatus::Error(err.to_string()),
265        };
266        let key = match str::from_utf8(&api_key) {
267            Ok(key) => key,
268            Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")),
269        };
270        LoadStatus::Loaded(Self {
271            source: ApiKeySource::SystemKeychain,
272            key: key.into(),
273        })
274    }
275}
276
277impl LoadStatus {
278    fn into_authenticate_result(self) -> Result<ApiKey, AuthenticateError> {
279        match self {
280            LoadStatus::Loaded(api_key) => Ok(api_key),
281            LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound),
282            LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))),
283        }
284    }
285}
286
287#[derive(Debug, Clone)]
288enum ApiKeySource {
289    EnvVar(SharedString),
290    SystemKeychain,
291}
292
293impl Display for ApiKeySource {
294    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
295        match self {
296            ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var),
297            ApiKeySource::SystemKeychain => write!(f, "system keychain"),
298        }
299    }
300}