1use anyhow::{Result, anyhow};
2use credentials_provider::CredentialsProvider;
3use env_var::EnvVar;
4use futures::{FutureExt, future};
5use gpui::{AsyncApp, Context, SharedString, Task};
6use std::{
7 fmt::{Display, Formatter},
8 sync::Arc,
9};
10use util::ResultExt as _;
11
12use crate::AuthenticateError;
13
14/// Manages a single API key for a language model provider. API keys either come from environment
15/// variables or the system keychain.
16///
17/// Keys from the system keychain are associated with a provider URL, and this ensures that they are
18/// only used with that URL.
19pub struct ApiKeyState {
20 pub url: SharedString,
21 env_var: EnvVar,
22 load_status: LoadStatus,
23 load_task: Option<future::Shared<Task<()>>>,
24}
25
26#[derive(Debug, Clone)]
27pub enum LoadStatus {
28 NotPresent,
29 Error(String),
30 Loaded(ApiKey),
31}
32
33#[derive(Debug, Clone)]
34pub struct ApiKey {
35 source: ApiKeySource,
36 key: Arc<str>,
37}
38
39impl ApiKeyState {
40 pub fn new(url: SharedString, env_var: EnvVar) -> Self {
41 Self {
42 url,
43 env_var,
44 load_status: LoadStatus::NotPresent,
45 load_task: None,
46 }
47 }
48
49 pub fn has_key(&self) -> bool {
50 matches!(self.load_status, LoadStatus::Loaded { .. })
51 }
52
53 pub fn env_var_name(&self) -> &SharedString {
54 &self.env_var.name
55 }
56
57 pub fn is_from_env_var(&self) -> bool {
58 match &self.load_status {
59 LoadStatus::Loaded(ApiKey {
60 source: ApiKeySource::EnvVar { .. },
61 ..
62 }) => true,
63 _ => false,
64 }
65 }
66
67 /// Get the stored API key, verifying that it is associated with the URL. Returns `None` if
68 /// there is no key or for URL mismatches, and the mismatch case is logged.
69 ///
70 /// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been
71 /// called with this URL.
72 pub fn key(&self, url: &str) -> Option<Arc<str>> {
73 let api_key = match &self.load_status {
74 LoadStatus::Loaded(api_key) => api_key,
75 _ => return None,
76 };
77 if url == self.url.as_str() {
78 Some(api_key.key.clone())
79 } else if let ApiKeySource::EnvVar(var_name) = &api_key.source {
80 log::warn!(
81 "{} is now being used with URL {}, when initially it was used with URL {}",
82 var_name,
83 url,
84 self.url
85 );
86 Some(api_key.key.clone())
87 } else {
88 // bug case because load_if_needed should be called whenever the url may have changed
89 log::error!(
90 "bug: Attempted to use API key associated with URL {} instead with URL {}",
91 self.url,
92 url
93 );
94 None
95 }
96 }
97
98 /// Set or delete the API key in the system keychain.
99 pub fn store<Ent: 'static>(
100 &mut self,
101 url: SharedString,
102 key: Option<String>,
103 get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
104 provider: Arc<dyn CredentialsProvider>,
105 cx: &Context<Ent>,
106 ) -> Task<Result<()>> {
107 if self.is_from_env_var() {
108 return Task::ready(Err(anyhow!(
109 "bug: attempted to store API key in system keychain when API key is from env var",
110 )));
111 }
112 cx.spawn(async move |ent, cx| {
113 if let Some(key) = &key {
114 provider
115 .write_credentials(&url, "Bearer", key.as_bytes(), cx)
116 .await
117 .log_err();
118 } else {
119 provider.delete_credentials(&url, cx).await.log_err();
120 }
121 ent.update(cx, |ent, cx| {
122 let this = get_this(ent);
123 this.url = url;
124 this.load_status = match &key {
125 Some(key) => LoadStatus::Loaded(ApiKey {
126 source: ApiKeySource::SystemKeychain,
127 key: key.as_str().into(),
128 }),
129 None => LoadStatus::NotPresent,
130 };
131 cx.notify();
132 })
133 })
134 }
135
136 /// Reloads the API key if the current API key is associated with a different URL.
137 ///
138 /// Note that it is not efficient to use this or `load_if_needed` with multiple URLs
139 /// interchangeably - URL change should correspond to some user initiated change.
140 pub fn handle_url_change<Ent: 'static>(
141 &mut self,
142 url: SharedString,
143 get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
144 provider: Arc<dyn CredentialsProvider>,
145 cx: &mut Context<Ent>,
146 ) {
147 if url != self.url {
148 if !self.is_from_env_var() {
149 // loading will continue even though this result task is dropped
150 let _task = self.load_if_needed(url, get_this, provider, cx);
151 }
152 }
153 }
154
155 /// If needed, loads the API key associated with the given URL from the system keychain. When a
156 /// non-empty environment variable is provided, it will be used instead. If called when an API
157 /// key was already loaded for a different URL, that key will be cleared before loading.
158 ///
159 /// Dropping the returned Task does not cancel key loading.
160 pub fn load_if_needed<Ent: 'static>(
161 &mut self,
162 url: SharedString,
163 get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
164 provider: Arc<dyn CredentialsProvider>,
165 cx: &mut Context<Ent>,
166 ) -> Task<Result<(), AuthenticateError>> {
167 if let LoadStatus::Loaded { .. } = &self.load_status
168 && self.url == url
169 {
170 return Task::ready(Ok(()));
171 }
172
173 if let Some(key) = &self.env_var.value
174 && !key.is_empty()
175 {
176 let api_key = ApiKey::from_env(self.env_var.name.clone(), key);
177 self.url = url;
178 self.load_status = LoadStatus::Loaded(api_key);
179 self.load_task = None;
180 cx.notify();
181 return Task::ready(Ok(()));
182 }
183
184 let task = if let Some(load_task) = &self.load_task {
185 load_task.clone()
186 } else {
187 let load_task = Self::load(url.clone(), get_this.clone(), provider, cx).shared();
188 self.url = url;
189 self.load_status = LoadStatus::NotPresent;
190 self.load_task = Some(load_task.clone());
191 cx.notify();
192 load_task
193 };
194
195 cx.spawn(async move |ent, cx| {
196 task.await;
197 ent.update(cx, |ent, _cx| {
198 get_this(ent).load_status.clone().into_authenticate_result()
199 })
200 .ok();
201 Ok(())
202 })
203 }
204
205 fn load<Ent: 'static>(
206 url: SharedString,
207 get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
208 provider: Arc<dyn CredentialsProvider>,
209 cx: &Context<Ent>,
210 ) -> Task<()> {
211 cx.spawn({
212 async move |ent, cx| {
213 let load_status =
214 ApiKey::load_from_system_keychain_impl(&url, provider.as_ref(), cx).await;
215 ent.update(cx, |ent, cx| {
216 let this = get_this(ent);
217 this.url = url;
218 this.load_status = load_status;
219 this.load_task = None;
220 cx.notify();
221 })
222 .ok();
223 }
224 })
225 }
226}
227
228impl ApiKey {
229 pub fn key(&self) -> &str {
230 &self.key
231 }
232
233 pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
234 Self {
235 source: ApiKeySource::EnvVar(env_var_name),
236 key: key.into(),
237 }
238 }
239
240 pub async fn load_from_system_keychain(
241 url: &str,
242 credentials_provider: &dyn CredentialsProvider,
243 cx: &AsyncApp,
244 ) -> Result<Self, AuthenticateError> {
245 Self::load_from_system_keychain_impl(url, credentials_provider, cx)
246 .await
247 .into_authenticate_result()
248 }
249
250 async fn load_from_system_keychain_impl(
251 url: &str,
252 credentials_provider: &dyn CredentialsProvider,
253 cx: &AsyncApp,
254 ) -> LoadStatus {
255 if url.is_empty() {
256 return LoadStatus::NotPresent;
257 }
258 let read_result = credentials_provider.read_credentials(&url, cx).await;
259 let api_key = match read_result {
260 Ok(Some((_, api_key))) => api_key,
261 Ok(None) => return LoadStatus::NotPresent,
262 Err(err) => return LoadStatus::Error(err.to_string()),
263 };
264 let key = match str::from_utf8(&api_key) {
265 Ok(key) => key,
266 Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")),
267 };
268 LoadStatus::Loaded(Self {
269 source: ApiKeySource::SystemKeychain,
270 key: key.into(),
271 })
272 }
273}
274
275impl LoadStatus {
276 fn into_authenticate_result(self) -> Result<ApiKey, AuthenticateError> {
277 match self {
278 LoadStatus::Loaded(api_key) => Ok(api_key),
279 LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound),
280 LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))),
281 }
282 }
283}
284
285#[derive(Debug, Clone)]
286enum ApiKeySource {
287 EnvVar(SharedString),
288 SystemKeychain,
289}
290
291impl Display for ApiKeySource {
292 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
293 match self {
294 ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var),
295 ApiKeySource::SystemKeychain => write!(f, "system keychain"),
296 }
297 }
298}