1use anyhow::{Result, anyhow};
2use credentials_provider::CredentialsProvider;
3use futures::{FutureExt, future};
4use gpui::{AsyncApp, Context, SharedString, Task};
5use std::{
6 fmt::{Display, Formatter},
7 sync::Arc,
8};
9use util::ResultExt as _;
10use zed_env_vars::EnvVar;
11
12use crate::AuthenticateError;
13
14/// Manages a single API key for a language model provider. API keys either come from environment
15/// variables or the system keychain.
16///
17/// Keys from the system keychain are associated with a provider URL, and this ensures that they are
18/// only used with that URL.
19pub struct ApiKeyState {
20 pub url: SharedString,
21 env_var: EnvVar,
22 load_status: LoadStatus,
23 load_task: Option<future::Shared<Task<()>>>,
24}
25
26#[derive(Debug, Clone)]
27pub enum LoadStatus {
28 NotPresent,
29 Error(String),
30 Loaded(ApiKey),
31}
32
33#[derive(Debug, Clone)]
34pub struct ApiKey {
35 source: ApiKeySource,
36 key: Arc<str>,
37}
38
39impl ApiKeyState {
40 pub fn new(url: SharedString, env_var: EnvVar) -> Self {
41 Self {
42 url,
43 env_var,
44 load_status: LoadStatus::NotPresent,
45 load_task: None,
46 }
47 }
48
49 pub fn has_key(&self) -> bool {
50 matches!(self.load_status, LoadStatus::Loaded { .. })
51 }
52
53 pub fn env_var_name(&self) -> &SharedString {
54 &self.env_var.name
55 }
56
57 pub fn is_from_env_var(&self) -> bool {
58 match &self.load_status {
59 LoadStatus::Loaded(ApiKey {
60 source: ApiKeySource::EnvVar { .. },
61 ..
62 }) => true,
63 _ => false,
64 }
65 }
66
67 /// Get the stored API key, verifying that it is associated with the URL. Returns `None` if
68 /// there is no key or for URL mismatches, and the mismatch case is logged.
69 ///
70 /// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been
71 /// called with this URL.
72 pub fn key(&self, url: &str) -> Option<Arc<str>> {
73 let api_key = match &self.load_status {
74 LoadStatus::Loaded(api_key) => api_key,
75 _ => return None,
76 };
77 if url == self.url.as_str() {
78 Some(api_key.key.clone())
79 } else if let ApiKeySource::EnvVar(var_name) = &api_key.source {
80 log::warn!(
81 "{} is now being used with URL {}, when initially it was used with URL {}",
82 var_name,
83 url,
84 self.url
85 );
86 Some(api_key.key.clone())
87 } else {
88 // bug case because load_if_needed should be called whenever the url may have changed
89 log::error!(
90 "bug: Attempted to use API key associated with URL {} instead with URL {}",
91 self.url,
92 url
93 );
94 None
95 }
96 }
97
98 /// Set or delete the API key in the system keychain.
99 pub fn store<Ent: 'static>(
100 &mut self,
101 url: SharedString,
102 key: Option<String>,
103 get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
104 cx: &Context<Ent>,
105 ) -> Task<Result<()>> {
106 if self.is_from_env_var() {
107 return Task::ready(Err(anyhow!(
108 "bug: attempted to store API key in system keychain when API key is from env var",
109 )));
110 }
111 let credentials_provider = <dyn CredentialsProvider>::global(cx);
112 cx.spawn(async move |ent, cx| {
113 if let Some(key) = &key {
114 credentials_provider
115 .write_credentials(&url, "Bearer", key.as_bytes(), cx)
116 .await
117 .log_err();
118 } else {
119 credentials_provider
120 .delete_credentials(&url, cx)
121 .await
122 .log_err();
123 }
124 ent.update(cx, |ent, cx| {
125 let this = get_this(ent);
126 this.url = url;
127 this.load_status = match &key {
128 Some(key) => LoadStatus::Loaded(ApiKey {
129 source: ApiKeySource::SystemKeychain,
130 key: key.as_str().into(),
131 }),
132 None => LoadStatus::NotPresent,
133 };
134 cx.notify();
135 })
136 })
137 }
138
139 /// Reloads the API key if the current API key is associated with a different URL.
140 ///
141 /// Note that it is not efficient to use this or `load_if_needed` with multiple URLs
142 /// interchangeably - URL change should correspond to some user initiated change.
143 pub fn handle_url_change<Ent: 'static>(
144 &mut self,
145 url: SharedString,
146 get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
147 cx: &mut Context<Ent>,
148 ) {
149 if url != self.url {
150 if !self.is_from_env_var() {
151 // loading will continue even though this result task is dropped
152 let _task = self.load_if_needed(url, get_this, cx);
153 }
154 }
155 }
156
157 /// If needed, loads the API key associated with the given URL from the system keychain. When a
158 /// non-empty environment variable is provided, it will be used instead. If called when an API
159 /// key was already loaded for a different URL, that key will be cleared before loading.
160 ///
161 /// Dropping the returned Task does not cancel key loading.
162 pub fn load_if_needed<Ent: 'static>(
163 &mut self,
164 url: SharedString,
165 get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
166 cx: &mut Context<Ent>,
167 ) -> Task<Result<(), AuthenticateError>> {
168 if let LoadStatus::Loaded { .. } = &self.load_status
169 && self.url == url
170 {
171 return Task::ready(Ok(()));
172 }
173
174 if let Some(key) = &self.env_var.value
175 && !key.is_empty()
176 {
177 let api_key = ApiKey::from_env(self.env_var.name.clone(), key);
178 self.url = url;
179 self.load_status = LoadStatus::Loaded(api_key);
180 self.load_task = None;
181 cx.notify();
182 return Task::ready(Ok(()));
183 }
184
185 let task = if let Some(load_task) = &self.load_task {
186 load_task.clone()
187 } else {
188 let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
189 self.url = url;
190 self.load_status = LoadStatus::NotPresent;
191 self.load_task = Some(load_task.clone());
192 cx.notify();
193 load_task
194 };
195
196 cx.spawn(async move |ent, cx| {
197 task.await;
198 ent.update(cx, |ent, _cx| {
199 get_this(ent).load_status.clone().into_authenticate_result()
200 })
201 .ok();
202 Ok(())
203 })
204 }
205
206 fn load<Ent: 'static>(
207 url: SharedString,
208 get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
209 cx: &Context<Ent>,
210 ) -> Task<()> {
211 let credentials_provider = <dyn CredentialsProvider>::global(cx);
212 cx.spawn({
213 async move |ent, cx| {
214 let load_status =
215 ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
216 .await;
217 ent.update(cx, |ent, cx| {
218 let this = get_this(ent);
219 this.url = url;
220 this.load_status = load_status;
221 this.load_task = None;
222 cx.notify();
223 })
224 .ok();
225 }
226 })
227 }
228}
229
230impl ApiKey {
231 pub fn key(&self) -> &str {
232 &self.key
233 }
234
235 pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
236 Self {
237 source: ApiKeySource::EnvVar(env_var_name),
238 key: key.into(),
239 }
240 }
241
242 pub async fn load_from_system_keychain(
243 url: &str,
244 credentials_provider: &dyn CredentialsProvider,
245 cx: &AsyncApp,
246 ) -> Result<Self, AuthenticateError> {
247 Self::load_from_system_keychain_impl(url, credentials_provider, cx)
248 .await
249 .into_authenticate_result()
250 }
251
252 async fn load_from_system_keychain_impl(
253 url: &str,
254 credentials_provider: &dyn CredentialsProvider,
255 cx: &AsyncApp,
256 ) -> LoadStatus {
257 if url.is_empty() {
258 return LoadStatus::NotPresent;
259 }
260 let read_result = credentials_provider.read_credentials(&url, cx).await;
261 let api_key = match read_result {
262 Ok(Some((_, api_key))) => api_key,
263 Ok(None) => return LoadStatus::NotPresent,
264 Err(err) => return LoadStatus::Error(err.to_string()),
265 };
266 let key = match str::from_utf8(&api_key) {
267 Ok(key) => key,
268 Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")),
269 };
270 LoadStatus::Loaded(Self {
271 source: ApiKeySource::SystemKeychain,
272 key: key.into(),
273 })
274 }
275}
276
277impl LoadStatus {
278 fn into_authenticate_result(self) -> Result<ApiKey, AuthenticateError> {
279 match self {
280 LoadStatus::Loaded(api_key) => Ok(api_key),
281 LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound),
282 LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))),
283 }
284 }
285}
286
287#[derive(Debug, Clone)]
288enum ApiKeySource {
289 EnvVar(SharedString),
290 SystemKeychain,
291}
292
293impl Display for ApiKeySource {
294 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
295 match self {
296 ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var),
297 ApiKeySource::SystemKeychain => write!(f, "system keychain"),
298 }
299 }
300}