1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::{str::FromStr, sync::Arc};
8use thiserror::Error;
9use util::maybe;
10
11pub fn init(cx: &mut App) {
12 let registry = cx.new(|_cx| LanguageModelRegistry::default());
13 cx.set_global(GlobalLanguageModelRegistry(registry));
14}
15
16struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
17
18impl Global for GlobalLanguageModelRegistry {}
19
20#[derive(Error)]
21pub enum ConfigurationError {
22 #[error("Configure at least one LLM provider to start using the panel.")]
23 NoProvider,
24 #[error("LLM Provider is not configured or does not support the configured model.")]
25 ModelNotFound,
26 #[error("{} LLM provider is not configured.", .0.name().0)]
27 ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
28 #[error("Using the {} LLM provider requires accepting the Terms of Service.",
29 .0.name().0)]
30 ProviderPendingTermsAcceptance(Arc<dyn LanguageModelProvider>),
31}
32
33impl std::fmt::Debug for ConfigurationError {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 Self::NoProvider => write!(f, "NoProvider"),
37 Self::ModelNotFound => write!(f, "ModelNotFound"),
38 Self::ProviderNotAuthenticated(provider) => {
39 write!(f, "ProviderNotAuthenticated({})", provider.id())
40 }
41 Self::ProviderPendingTermsAcceptance(provider) => {
42 write!(f, "ProviderPendingTermsAcceptance({})", provider.id())
43 }
44 }
45 }
46}
47
48#[derive(Default)]
49pub struct LanguageModelRegistry {
50 default_model: Option<ConfiguredModel>,
51 default_fast_model: Option<ConfiguredModel>,
52 inline_assistant_model: Option<ConfiguredModel>,
53 commit_message_model: Option<ConfiguredModel>,
54 thread_summary_model: Option<ConfiguredModel>,
55 providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
56 inline_alternatives: Vec<Arc<dyn LanguageModel>>,
57}
58
59#[derive(Debug)]
60pub struct SelectedModel {
61 pub provider: LanguageModelProviderId,
62 pub model: LanguageModelId,
63}
64
65impl FromStr for SelectedModel {
66 type Err = String;
67
68 /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
69 fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
70 let parts: Vec<&str> = id.split('/').collect();
71 let [provider_id, model_id] = parts.as_slice() else {
72 return Err(format!(
73 "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
74 id
75 ));
76 };
77
78 if provider_id.is_empty() || model_id.is_empty() {
79 return Err(format!("Provider and model ids can't be empty: `{}`", id));
80 }
81
82 Ok(SelectedModel {
83 provider: LanguageModelProviderId(provider_id.to_string().into()),
84 model: LanguageModelId(model_id.to_string().into()),
85 })
86 }
87}
88
89#[derive(Clone)]
90pub struct ConfiguredModel {
91 pub provider: Arc<dyn LanguageModelProvider>,
92 pub model: Arc<dyn LanguageModel>,
93}
94
95impl ConfiguredModel {
96 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
97 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
98 }
99
100 pub fn is_provided_by_zed(&self) -> bool {
101 self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
102 }
103}
104
105pub enum Event {
106 DefaultModelChanged,
107 InlineAssistantModelChanged,
108 CommitMessageModelChanged,
109 ThreadSummaryModelChanged,
110 ProviderStateChanged,
111 AddedProvider(LanguageModelProviderId),
112 RemovedProvider(LanguageModelProviderId),
113}
114
115impl EventEmitter<Event> for LanguageModelRegistry {}
116
117impl LanguageModelRegistry {
118 pub fn global(cx: &App) -> Entity<Self> {
119 cx.global::<GlobalLanguageModelRegistry>().0.clone()
120 }
121
122 pub fn read_global(cx: &App) -> &Self {
123 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
124 }
125
126 #[cfg(any(test, feature = "test-support"))]
127 pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
128 let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
129 let registry = cx.new(|cx| {
130 let mut registry = Self::default();
131 registry.register_provider(fake_provider.clone(), cx);
132 let model = fake_provider.provided_models(cx)[0].clone();
133 let configured_model = ConfiguredModel {
134 provider: Arc::new(fake_provider.clone()),
135 model,
136 };
137 registry.set_default_model(Some(configured_model), cx);
138 registry
139 });
140 cx.set_global(GlobalLanguageModelRegistry(registry));
141 fake_provider
142 }
143
144 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
145 &mut self,
146 provider: T,
147 cx: &mut Context<Self>,
148 ) {
149 let id = provider.id();
150
151 let subscription = provider.subscribe(cx, |_, cx| {
152 cx.emit(Event::ProviderStateChanged);
153 });
154 if let Some(subscription) = subscription {
155 subscription.detach();
156 }
157
158 self.providers.insert(id.clone(), Arc::new(provider));
159 cx.emit(Event::AddedProvider(id));
160 }
161
162 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
163 if self.providers.remove(&id).is_some() {
164 cx.emit(Event::RemovedProvider(id));
165 }
166 }
167
168 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
169 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
170 let mut providers = Vec::with_capacity(self.providers.len());
171 if let Some(provider) = self.providers.get(&zed_provider_id) {
172 providers.push(provider.clone());
173 }
174 providers.extend(self.providers.values().filter_map(|p| {
175 if p.id() != zed_provider_id {
176 Some(p.clone())
177 } else {
178 None
179 }
180 }));
181 providers
182 }
183
184 pub fn configuration_error(
185 &self,
186 model: Option<ConfiguredModel>,
187 cx: &App,
188 ) -> Option<ConfigurationError> {
189 let Some(model) = model else {
190 if !self.has_authenticated_provider(cx) {
191 return Some(ConfigurationError::NoProvider);
192 }
193 return Some(ConfigurationError::ModelNotFound);
194 };
195
196 if !model.provider.is_authenticated(cx) {
197 return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
198 }
199
200 if model.provider.must_accept_terms(cx) {
201 return Some(ConfigurationError::ProviderPendingTermsAcceptance(
202 model.provider,
203 ));
204 }
205
206 None
207 }
208
209 /// Returns `true` if at least one provider that is authenticated.
210 pub fn has_authenticated_provider(&self, cx: &App) -> bool {
211 self.providers.values().any(|p| p.is_authenticated(cx))
212 }
213
214 pub fn available_models<'a>(
215 &'a self,
216 cx: &'a App,
217 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
218 self.providers
219 .values()
220 .flat_map(|provider| provider.provided_models(cx))
221 }
222
223 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
224 self.providers.get(id).cloned()
225 }
226
227 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
228 let configured_model = model.and_then(|model| self.select_model(model, cx));
229 self.set_default_model(configured_model, cx);
230 }
231
232 pub fn select_inline_assistant_model(
233 &mut self,
234 model: Option<&SelectedModel>,
235 cx: &mut Context<Self>,
236 ) {
237 let configured_model = model.and_then(|model| self.select_model(model, cx));
238 self.set_inline_assistant_model(configured_model, cx);
239 }
240
241 pub fn select_commit_message_model(
242 &mut self,
243 model: Option<&SelectedModel>,
244 cx: &mut Context<Self>,
245 ) {
246 let configured_model = model.and_then(|model| self.select_model(model, cx));
247 self.set_commit_message_model(configured_model, cx);
248 }
249
250 pub fn select_thread_summary_model(
251 &mut self,
252 model: Option<&SelectedModel>,
253 cx: &mut Context<Self>,
254 ) {
255 let configured_model = model.and_then(|model| self.select_model(model, cx));
256 self.set_thread_summary_model(configured_model, cx);
257 }
258
259 /// Selects and sets the inline alternatives for language models based on
260 /// provider name and id.
261 pub fn select_inline_alternative_models(
262 &mut self,
263 alternatives: impl IntoIterator<Item = SelectedModel>,
264 cx: &mut Context<Self>,
265 ) {
266 self.inline_alternatives = alternatives
267 .into_iter()
268 .flat_map(|alternative| {
269 self.select_model(&alternative, cx)
270 .map(|configured_model| configured_model.model)
271 })
272 .collect::<Vec<_>>();
273 }
274
275 pub fn select_model(
276 &mut self,
277 selected_model: &SelectedModel,
278 cx: &mut Context<Self>,
279 ) -> Option<ConfiguredModel> {
280 let provider = self.provider(&selected_model.provider)?;
281 let model = provider
282 .provided_models(cx)
283 .iter()
284 .find(|model| model.id() == selected_model.model)?
285 .clone();
286 Some(ConfiguredModel { provider, model })
287 }
288
289 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
290 match (self.default_model.as_ref(), model.as_ref()) {
291 (Some(old), Some(new)) if old.is_same_as(new) => {}
292 (None, None) => {}
293 _ => cx.emit(Event::DefaultModelChanged),
294 }
295 self.default_fast_model = maybe!({
296 let provider = &model.as_ref()?.provider;
297 let fast_model = provider.default_fast_model(cx)?;
298 Some(ConfiguredModel {
299 provider: provider.clone(),
300 model: fast_model,
301 })
302 });
303 self.default_model = model;
304 }
305
306 pub fn set_inline_assistant_model(
307 &mut self,
308 model: Option<ConfiguredModel>,
309 cx: &mut Context<Self>,
310 ) {
311 match (self.inline_assistant_model.as_ref(), model.as_ref()) {
312 (Some(old), Some(new)) if old.is_same_as(new) => {}
313 (None, None) => {}
314 _ => cx.emit(Event::InlineAssistantModelChanged),
315 }
316 self.inline_assistant_model = model;
317 }
318
319 pub fn set_commit_message_model(
320 &mut self,
321 model: Option<ConfiguredModel>,
322 cx: &mut Context<Self>,
323 ) {
324 match (self.commit_message_model.as_ref(), model.as_ref()) {
325 (Some(old), Some(new)) if old.is_same_as(new) => {}
326 (None, None) => {}
327 _ => cx.emit(Event::CommitMessageModelChanged),
328 }
329 self.commit_message_model = model;
330 }
331
332 pub fn set_thread_summary_model(
333 &mut self,
334 model: Option<ConfiguredModel>,
335 cx: &mut Context<Self>,
336 ) {
337 match (self.thread_summary_model.as_ref(), model.as_ref()) {
338 (Some(old), Some(new)) if old.is_same_as(new) => {}
339 (None, None) => {}
340 _ => cx.emit(Event::ThreadSummaryModelChanged),
341 }
342 self.thread_summary_model = model;
343 }
344
345 pub fn default_model(&self) -> Option<ConfiguredModel> {
346 #[cfg(debug_assertions)]
347 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
348 return None;
349 }
350
351 self.default_model.clone()
352 }
353
354 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
355 #[cfg(debug_assertions)]
356 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
357 return None;
358 }
359
360 self.inline_assistant_model
361 .clone()
362 .or_else(|| self.default_model.clone())
363 }
364
365 pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
366 #[cfg(debug_assertions)]
367 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
368 return None;
369 }
370
371 self.commit_message_model
372 .clone()
373 .or_else(|| self.default_fast_model.clone())
374 .or_else(|| self.default_model.clone())
375 }
376
377 pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
378 #[cfg(debug_assertions)]
379 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
380 return None;
381 }
382
383 self.thread_summary_model
384 .clone()
385 .or_else(|| self.default_fast_model.clone())
386 .or_else(|| self.default_model.clone())
387 }
388
389 /// The models to use for inline assists. Returns the union of the active
390 /// model and all inline alternatives. When there are multiple models, the
391 /// user will be able to cycle through results.
392 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
393 &self.inline_alternatives
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Round-trips a provider through `register_provider` and
    /// `unregister_provider`, checking `providers()` after each step.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();

        registry.update(cx, |registry, cx| {
            registry.register_provider(fake.clone(), cx);
        });
        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake.id());

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(fake.id(), cx);
        });
        assert!(registry.read(cx).providers().is_empty());
    }
}