1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::{str::FromStr, sync::Arc};
8use thiserror::Error;
9use util::maybe;
10
11pub fn init(cx: &mut App) {
12 let registry = cx.new(|_cx| LanguageModelRegistry::default());
13 cx.set_global(GlobalLanguageModelRegistry(registry));
14}
15
/// Newtype that stores the app-wide [`LanguageModelRegistry`] entity as a gpui global.
///
/// Installed by [`init`] (and by `LanguageModelRegistry::test` in test builds) and read back by
/// `LanguageModelRegistry::global` / `LanguageModelRegistry::read_global`.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
19
/// Reasons a model cannot be used right now, as reported by
/// `LanguageModelRegistry::configuration_error`.
///
/// `Display` comes from the `#[error]` attributes. `Debug` is implemented by hand for this type
/// rather than derived.
#[derive(Error)]
pub enum ConfigurationError {
    /// No model is configured and no registered provider is authenticated.
    #[error("Configure at least one LLM provider to start using the panel.")]
    NoProvider,
    /// No model is configured, although at least one registered provider is authenticated.
    #[error("LLM provider is not configured or does not support the configured model.")]
    ModelNotFound,
    /// The configured model's provider is not authenticated.
    #[error("{} LLM provider is not configured.", .0.name().0)]
    ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
    /// The configured model's provider reports that its terms must still be accepted.
    #[error("Using the {} LLM provider requires accepting the Terms of Service.",
    .0.name().0)]
    ProviderPendingTermsAcceptance(Arc<dyn LanguageModelProvider>),
}
32
33impl std::fmt::Debug for ConfigurationError {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 Self::NoProvider => write!(f, "NoProvider"),
37 Self::ModelNotFound => write!(f, "ModelNotFound"),
38 Self::ProviderNotAuthenticated(provider) => {
39 write!(f, "ProviderNotAuthenticated({})", provider.id())
40 }
41 Self::ProviderPendingTermsAcceptance(provider) => {
42 write!(f, "ProviderPendingTermsAcceptance({})", provider.id())
43 }
44 }
45 }
46}
47
/// Holds the registered language-model providers and the user's model selections.
///
/// The role-specific models (inline assistant, commit message, thread summary) are optional.
/// When one is unset, its getter falls back to other selections.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// The user's default model; the last fallback for every role-specific getter.
    default_model: Option<ConfiguredModel>,
    /// Recomputed in `set_default_model` from the default model's provider. It is the first
    /// fallback for the commit-message and thread-summary models.
    default_fast_model: Option<ConfiguredModel>,
    /// Explicit inline-assistant model; its getter falls back to `default_model`.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Explicit commit-message model; falls back to `default_fast_model`, then `default_model`.
    commit_message_model: Option<ConfiguredModel>,
    /// Explicit thread-summary model; falls back to `default_fast_model`, then `default_model`.
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers keyed by provider id. It is a `BTreeMap`, so iteration is in key order.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Extra models for inline assists, replaced wholesale by `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
58
59#[derive(Debug)]
60pub struct SelectedModel {
61 pub provider: LanguageModelProviderId,
62 pub model: LanguageModelId,
63}
64
65impl FromStr for SelectedModel {
66 type Err = String;
67
68 /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel`
69 fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
70 let parts: Vec<&str> = id.split('/').collect();
71 let [provider_id, model_id] = parts.as_slice() else {
72 return Err(format!(
73 "Invalid model identifier format: `{}`. Expected `provider_id/model_id`",
74 id
75 ));
76 };
77
78 if provider_id.is_empty() || model_id.is_empty() {
79 return Err(format!("Provider and model ids can't be empty: `{}`", id));
80 }
81
82 Ok(SelectedModel {
83 provider: LanguageModelProviderId(provider_id.to_string().into()),
84 model: LanguageModelId(model_id.to_string().into()),
85 })
86 }
87}
88
89#[derive(Clone)]
90pub struct ConfiguredModel {
91 pub provider: Arc<dyn LanguageModelProvider>,
92 pub model: Arc<dyn LanguageModel>,
93}
94
95impl ConfiguredModel {
96 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
97 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
98 }
99
100 pub fn is_provided_by_zed(&self) -> bool {
101 self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
102 }
103}
104
/// Notifications emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// The default model changed: different provider/model ids, or a switch to/from `None`.
    DefaultModelChanged,
    /// The explicitly set inline-assistant model changed.
    InlineAssistantModelChanged,
    /// The explicitly set commit-message model changed.
    CommitMessageModelChanged,
    /// The explicitly set thread-summary model changed.
    ThreadSummaryModelChanged,
    /// A provider's state subscription fired; carries that provider's id.
    ProviderStateChanged(LanguageModelProviderId),
    /// A provider was registered. Also emitted when an already-registered id is registered again.
    AddedProvider(LanguageModelProviderId),
    /// A provider that was actually present has been removed.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
116
117impl LanguageModelRegistry {
118 pub fn global(cx: &App) -> Entity<Self> {
119 cx.global::<GlobalLanguageModelRegistry>().0.clone()
120 }
121
122 pub fn read_global(cx: &App) -> &Self {
123 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
124 }
125
126 #[cfg(any(test, feature = "test-support"))]
127 pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
128 let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
129 let registry = cx.new(|cx| {
130 let mut registry = Self::default();
131 registry.register_provider(fake_provider.clone(), cx);
132 let model = fake_provider.provided_models(cx)[0].clone();
133 let configured_model = ConfiguredModel {
134 provider: Arc::new(fake_provider.clone()),
135 model,
136 };
137 registry.set_default_model(Some(configured_model), cx);
138 registry
139 });
140 cx.set_global(GlobalLanguageModelRegistry(registry));
141 fake_provider
142 }
143
144 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
145 &mut self,
146 provider: T,
147 cx: &mut Context<Self>,
148 ) {
149 let id = provider.id();
150
151 let subscription = provider.subscribe(cx, {
152 let id = id.clone();
153 move |_, cx| {
154 cx.emit(Event::ProviderStateChanged(id.clone()));
155 }
156 });
157 if let Some(subscription) = subscription {
158 subscription.detach();
159 }
160
161 self.providers.insert(id.clone(), Arc::new(provider));
162 cx.emit(Event::AddedProvider(id));
163 }
164
165 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
166 if self.providers.remove(&id).is_some() {
167 cx.emit(Event::RemovedProvider(id));
168 }
169 }
170
171 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
172 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
173 let mut providers = Vec::with_capacity(self.providers.len());
174 if let Some(provider) = self.providers.get(&zed_provider_id) {
175 providers.push(provider.clone());
176 }
177 providers.extend(self.providers.values().filter_map(|p| {
178 if p.id() != zed_provider_id {
179 Some(p.clone())
180 } else {
181 None
182 }
183 }));
184 providers
185 }
186
187 pub fn configuration_error(
188 &self,
189 model: Option<ConfiguredModel>,
190 cx: &App,
191 ) -> Option<ConfigurationError> {
192 let Some(model) = model else {
193 if !self.has_authenticated_provider(cx) {
194 return Some(ConfigurationError::NoProvider);
195 }
196 return Some(ConfigurationError::ModelNotFound);
197 };
198
199 if !model.provider.is_authenticated(cx) {
200 return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
201 }
202
203 if model.provider.must_accept_terms(cx) {
204 return Some(ConfigurationError::ProviderPendingTermsAcceptance(
205 model.provider,
206 ));
207 }
208
209 None
210 }
211
212 /// Returns `true` if at least one provider that is authenticated.
213 pub fn has_authenticated_provider(&self, cx: &App) -> bool {
214 self.providers.values().any(|p| p.is_authenticated(cx))
215 }
216
217 pub fn available_models<'a>(
218 &'a self,
219 cx: &'a App,
220 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
221 self.providers
222 .values()
223 .flat_map(|provider| provider.provided_models(cx))
224 }
225
226 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
227 self.providers.get(id).cloned()
228 }
229
230 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
231 let configured_model = model.and_then(|model| self.select_model(model, cx));
232 self.set_default_model(configured_model, cx);
233 }
234
235 pub fn select_inline_assistant_model(
236 &mut self,
237 model: Option<&SelectedModel>,
238 cx: &mut Context<Self>,
239 ) {
240 let configured_model = model.and_then(|model| self.select_model(model, cx));
241 self.set_inline_assistant_model(configured_model, cx);
242 }
243
244 pub fn select_commit_message_model(
245 &mut self,
246 model: Option<&SelectedModel>,
247 cx: &mut Context<Self>,
248 ) {
249 let configured_model = model.and_then(|model| self.select_model(model, cx));
250 self.set_commit_message_model(configured_model, cx);
251 }
252
253 pub fn select_thread_summary_model(
254 &mut self,
255 model: Option<&SelectedModel>,
256 cx: &mut Context<Self>,
257 ) {
258 let configured_model = model.and_then(|model| self.select_model(model, cx));
259 self.set_thread_summary_model(configured_model, cx);
260 }
261
262 /// Selects and sets the inline alternatives for language models based on
263 /// provider name and id.
264 pub fn select_inline_alternative_models(
265 &mut self,
266 alternatives: impl IntoIterator<Item = SelectedModel>,
267 cx: &mut Context<Self>,
268 ) {
269 self.inline_alternatives = alternatives
270 .into_iter()
271 .flat_map(|alternative| {
272 self.select_model(&alternative, cx)
273 .map(|configured_model| configured_model.model)
274 })
275 .collect::<Vec<_>>();
276 }
277
278 pub fn select_model(
279 &mut self,
280 selected_model: &SelectedModel,
281 cx: &mut Context<Self>,
282 ) -> Option<ConfiguredModel> {
283 let provider = self.provider(&selected_model.provider)?;
284 let model = provider
285 .provided_models(cx)
286 .iter()
287 .find(|model| model.id() == selected_model.model)?
288 .clone();
289 Some(ConfiguredModel { provider, model })
290 }
291
292 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
293 match (self.default_model.as_ref(), model.as_ref()) {
294 (Some(old), Some(new)) if old.is_same_as(new) => {}
295 (None, None) => {}
296 _ => cx.emit(Event::DefaultModelChanged),
297 }
298 self.default_fast_model = maybe!({
299 let provider = &model.as_ref()?.provider;
300 let fast_model = provider.default_fast_model(cx)?;
301 Some(ConfiguredModel {
302 provider: provider.clone(),
303 model: fast_model,
304 })
305 });
306 self.default_model = model;
307 }
308
309 pub fn set_inline_assistant_model(
310 &mut self,
311 model: Option<ConfiguredModel>,
312 cx: &mut Context<Self>,
313 ) {
314 match (self.inline_assistant_model.as_ref(), model.as_ref()) {
315 (Some(old), Some(new)) if old.is_same_as(new) => {}
316 (None, None) => {}
317 _ => cx.emit(Event::InlineAssistantModelChanged),
318 }
319 self.inline_assistant_model = model;
320 }
321
322 pub fn set_commit_message_model(
323 &mut self,
324 model: Option<ConfiguredModel>,
325 cx: &mut Context<Self>,
326 ) {
327 match (self.commit_message_model.as_ref(), model.as_ref()) {
328 (Some(old), Some(new)) if old.is_same_as(new) => {}
329 (None, None) => {}
330 _ => cx.emit(Event::CommitMessageModelChanged),
331 }
332 self.commit_message_model = model;
333 }
334
335 pub fn set_thread_summary_model(
336 &mut self,
337 model: Option<ConfiguredModel>,
338 cx: &mut Context<Self>,
339 ) {
340 match (self.thread_summary_model.as_ref(), model.as_ref()) {
341 (Some(old), Some(new)) if old.is_same_as(new) => {}
342 (None, None) => {}
343 _ => cx.emit(Event::ThreadSummaryModelChanged),
344 }
345 self.thread_summary_model = model;
346 }
347
348 pub fn default_model(&self) -> Option<ConfiguredModel> {
349 #[cfg(debug_assertions)]
350 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
351 return None;
352 }
353
354 self.default_model.clone()
355 }
356
357 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
358 #[cfg(debug_assertions)]
359 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
360 return None;
361 }
362
363 self.inline_assistant_model
364 .clone()
365 .or_else(|| self.default_model.clone())
366 }
367
368 pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
369 #[cfg(debug_assertions)]
370 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
371 return None;
372 }
373
374 self.commit_message_model
375 .clone()
376 .or_else(|| self.default_fast_model.clone())
377 .or_else(|| self.default_model.clone())
378 }
379
380 pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
381 #[cfg(debug_assertions)]
382 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
383 return None;
384 }
385
386 self.thread_summary_model
387 .clone()
388 .or_else(|| self.default_fast_model.clone())
389 .or_else(|| self.default_model.clone())
390 }
391
392 /// The models to use for inline assists. Returns the union of the active
393 /// model and all inline alternatives. When there are multiple models, the
394 /// user will be able to cycle through results.
395 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
396 &self.inline_alternatives
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering a provider makes it visible through `providers()`, and unregistering it by
    /// id removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake = FakeLanguageModelProvider::default();

        registry.update(cx, |registry, cx| registry.register_provider(fake.clone(), cx));

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake.id());

        registry.update(cx, |registry, cx| registry.unregister_provider(fake.id(), cx));

        assert!(registry.read(cx).providers().is_empty());
    }
}