1use anyhow::{Result, anyhow};
2use credentials_provider::CredentialsProvider;
3use futures::{FutureExt, future};
4use gpui::{AsyncApp, Context, SharedString, Task};
5use language_model::AuthenticateError;
6use std::{
7 fmt::{Display, Formatter},
8 sync::Arc,
9};
10use util::ResultExt as _;
11use zed_env_vars::EnvVar;
12
13/// Manages a single API key for a language model provider. API keys either come from environment
14/// variables or the system keychain.
15///
16/// Keys from the system keychain are associated with a provider URL, and this ensures that they are
17/// only used with that URL.
18pub struct ApiKeyState {
19 url: SharedString,
20 load_status: LoadStatus,
21 load_task: Option<future::Shared<Task<()>>>,
22}
23
24#[derive(Debug, Clone)]
25pub enum LoadStatus {
26 NotPresent,
27 Error(String),
28 Loaded(ApiKey),
29}
30
31#[derive(Debug, Clone)]
32pub struct ApiKey {
33 source: ApiKeySource,
34 key: Arc<str>,
35}
36
37impl ApiKeyState {
38 pub fn new(url: SharedString) -> Self {
39 Self {
40 url,
41 load_status: LoadStatus::NotPresent,
42 load_task: None,
43 }
44 }
45
46 pub fn has_key(&self) -> bool {
47 matches!(self.load_status, LoadStatus::Loaded { .. })
48 }
49
50 pub fn is_from_env_var(&self) -> bool {
51 match &self.load_status {
52 LoadStatus::Loaded(ApiKey {
53 source: ApiKeySource::EnvVar { .. },
54 ..
55 }) => true,
56 _ => false,
57 }
58 }
59
60 /// Get the stored API key, verifying that it is associated with the URL. Returns `None` if
61 /// there is no key or for URL mismatches, and the mismatch case is logged.
62 ///
63 /// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been
64 /// called with this URL.
65 pub fn key(&self, url: &str) -> Option<Arc<str>> {
66 let api_key = match &self.load_status {
67 LoadStatus::Loaded(api_key) => api_key,
68 _ => return None,
69 };
70 if url == self.url.as_str() {
71 Some(api_key.key.clone())
72 } else if let ApiKeySource::EnvVar(var_name) = &api_key.source {
73 log::warn!(
74 "{} is now being used with URL {}, when initially it was used with URL {}",
75 var_name,
76 url,
77 self.url
78 );
79 Some(api_key.key.clone())
80 } else {
81 // bug case because load_if_needed should be called whenever the url may have changed
82 log::error!(
83 "bug: Attempted to use API key associated with URL {} instead with URL {}",
84 self.url,
85 url
86 );
87 None
88 }
89 }
90
91 /// Set or delete the API key in the system keychain.
92 pub fn store<Ent: 'static>(
93 &mut self,
94 url: SharedString,
95 key: Option<String>,
96 get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
97 cx: &Context<Ent>,
98 ) -> Task<Result<()>> {
99 if self.is_from_env_var() {
100 return Task::ready(Err(anyhow!(
101 "bug: attempted to store API key in system keychain when API key is from env var",
102 )));
103 }
104 let credentials_provider = <dyn CredentialsProvider>::global(cx);
105 cx.spawn(async move |ent, cx| {
106 if let Some(key) = &key {
107 credentials_provider
108 .write_credentials(&url, "Bearer", key.as_bytes(), cx)
109 .await
110 .log_err();
111 } else {
112 credentials_provider
113 .delete_credentials(&url, cx)
114 .await
115 .log_err();
116 }
117 ent.update(cx, |ent, cx| {
118 let this = get_this(ent);
119 this.url = url;
120 this.load_status = match &key {
121 Some(key) => LoadStatus::Loaded(ApiKey {
122 source: ApiKeySource::SystemKeychain,
123 key: key.as_str().into(),
124 }),
125 None => LoadStatus::NotPresent,
126 };
127 cx.notify();
128 })
129 })
130 }
131
132 /// Reloads the API key if the current API key is associated with a different URL.
133 ///
134 /// Note that it is not efficient to use this or `load_if_needed` with multiple URLs
135 /// interchangeably - URL change should correspond to some user initiated change.
136 pub fn handle_url_change<Ent: 'static>(
137 &mut self,
138 url: SharedString,
139 env_var: &EnvVar,
140 get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
141 cx: &mut Context<Ent>,
142 ) {
143 if url != self.url {
144 if !self.is_from_env_var() {
145 // loading will continue even though this result task is dropped
146 let _task = self.load_if_needed(url, env_var, get_this, cx);
147 }
148 }
149 }
150
151 /// If needed, loads the API key associated with the given URL from the system keychain. When a
152 /// non-empty environment variable is provided, it will be used instead. If called when an API
153 /// key was already loaded for a different URL, that key will be cleared before loading.
154 ///
155 /// Dropping the returned Task does not cancel key loading.
156 pub fn load_if_needed<Ent: 'static>(
157 &mut self,
158 url: SharedString,
159 env_var: &EnvVar,
160 get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
161 cx: &mut Context<Ent>,
162 ) -> Task<Result<(), AuthenticateError>> {
163 if let LoadStatus::Loaded { .. } = &self.load_status
164 && self.url == url
165 {
166 return Task::ready(Ok(()));
167 }
168
169 if let Some(key) = &env_var.value
170 && !key.is_empty()
171 {
172 let api_key = ApiKey::from_env(env_var.name.clone(), key);
173 self.url = url;
174 self.load_status = LoadStatus::Loaded(api_key);
175 self.load_task = None;
176 cx.notify();
177 return Task::ready(Ok(()));
178 }
179
180 let task = if let Some(load_task) = &self.load_task {
181 load_task.clone()
182 } else {
183 let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
184 self.url = url;
185 self.load_status = LoadStatus::NotPresent;
186 self.load_task = Some(load_task.clone());
187 cx.notify();
188 load_task
189 };
190
191 cx.spawn(async move |ent, cx| {
192 task.await;
193 ent.update(cx, |ent, _cx| {
194 get_this(ent).load_status.clone().into_authenticate_result()
195 })
196 .ok();
197 Ok(())
198 })
199 }
200
201 fn load<Ent: 'static>(
202 url: SharedString,
203 get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
204 cx: &Context<Ent>,
205 ) -> Task<()> {
206 let credentials_provider = <dyn CredentialsProvider>::global(cx);
207 cx.spawn({
208 async move |ent, cx| {
209 let load_status =
210 ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
211 .await;
212 ent.update(cx, |ent, cx| {
213 let this = get_this(ent);
214 this.url = url;
215 this.load_status = load_status;
216 this.load_task = None;
217 cx.notify();
218 })
219 .ok();
220 }
221 })
222 }
223}
224
225impl ApiKey {
226 pub fn key(&self) -> &str {
227 &self.key
228 }
229
230 pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
231 Self {
232 source: ApiKeySource::EnvVar(env_var_name),
233 key: key.into(),
234 }
235 }
236
237 pub async fn load_from_system_keychain(
238 url: &str,
239 credentials_provider: &dyn CredentialsProvider,
240 cx: &AsyncApp,
241 ) -> Result<Self, AuthenticateError> {
242 Self::load_from_system_keychain_impl(url, credentials_provider, cx)
243 .await
244 .into_authenticate_result()
245 }
246
247 async fn load_from_system_keychain_impl(
248 url: &str,
249 credentials_provider: &dyn CredentialsProvider,
250 cx: &AsyncApp,
251 ) -> LoadStatus {
252 if url.is_empty() {
253 return LoadStatus::NotPresent;
254 }
255 let read_result = credentials_provider.read_credentials(&url, cx).await;
256 let api_key = match read_result {
257 Ok(Some((_, api_key))) => api_key,
258 Ok(None) => return LoadStatus::NotPresent,
259 Err(err) => return LoadStatus::Error(err.to_string()),
260 };
261 let key = match str::from_utf8(&api_key) {
262 Ok(key) => key,
263 Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")),
264 };
265 LoadStatus::Loaded(Self {
266 source: ApiKeySource::SystemKeychain,
267 key: key.into(),
268 })
269 }
270}
271
272impl LoadStatus {
273 fn into_authenticate_result(self) -> Result<ApiKey, AuthenticateError> {
274 match self {
275 LoadStatus::Loaded(api_key) => Ok(api_key),
276 LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound),
277 LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))),
278 }
279 }
280}
281
282#[derive(Debug, Clone)]
283enum ApiKeySource {
284 EnvVar(SharedString),
285 SystemKeychain,
286}
287
288impl Display for ApiKeySource {
289 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
290 match self {
291 ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var),
292 ApiKeySource::SystemKeychain => write!(f, "system keychain"),
293 }
294 }
295}