1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::sync::Arc;
8
9pub fn init(cx: &mut App) {
10 let registry = cx.new(|_cx| LanguageModelRegistry::default());
11 cx.set_global(GlobalLanguageModelRegistry(registry));
12}
13
/// gpui global wrapper holding the single registry entity installed by `init`
/// (or by `LanguageModelRegistry::test` in tests).
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

// Marker impl so the wrapper can be stored with `cx.set_global` / read with `cx.global`.
impl Global for GlobalLanguageModelRegistry {}
17
/// Registry of language model providers together with the model chosen for each feature.
///
/// The per-feature selections (inline assistant, commit message, thread summary) are
/// optional overrides; their accessors fall back to the default model when unset.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Model used when a feature has no override of its own.
    default_model: Option<ConfiguredModel>,
    /// Override for the inline assistant.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Override for commit message generation.
    commit_message_model: Option<ConfiguredModel>,
    /// Override for thread summaries.
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers keyed by their id; a `BTreeMap`, so iteration is ordered by id.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Extra models offered alongside the inline assistant model.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
27
/// A selected model paired with the registered provider that serves it.
#[derive(Clone)]
pub struct ConfiguredModel {
    /// Provider looked up in the registry by the model's `provider_id()` when the model was set.
    pub provider: Arc<dyn LanguageModelProvider>,
    /// The selected model itself.
    pub model: Arc<dyn LanguageModel>,
}
33
/// Notifications emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// The default model was set or cleared.
    DefaultModelChanged,
    /// The inline assistant override was set or cleared.
    InlineAssistantModelChanged,
    /// The commit message override was set or cleared.
    CommitMessageModelChanged,
    /// The thread summary override was set or cleared.
    ThreadSummaryModelChanged,
    /// A registered provider reported a change in its observable state.
    ProviderStateChanged,
    /// A provider with this id was registered.
    AddedProvider(LanguageModelProviderId),
    /// The provider with this id was removed from the registry.
    RemovedProvider(LanguageModelProviderId),
}

// Lets registry code `cx.emit(Event::...)` and lets others subscribe to those events.
impl EventEmitter<Event> for LanguageModelRegistry {}
45
46impl LanguageModelRegistry {
47 pub fn global(cx: &App) -> Entity<Self> {
48 cx.global::<GlobalLanguageModelRegistry>().0.clone()
49 }
50
51 pub fn read_global(cx: &App) -> &Self {
52 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
53 }
54
55 #[cfg(any(test, feature = "test-support"))]
56 pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
57 let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
58 let registry = cx.new(|cx| {
59 let mut registry = Self::default();
60 registry.register_provider(fake_provider.clone(), cx);
61 let model = fake_provider.provided_models(cx)[0].clone();
62 registry.set_default_model(Some(model), cx);
63 registry
64 });
65 cx.set_global(GlobalLanguageModelRegistry(registry));
66 fake_provider
67 }
68
69 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
70 &mut self,
71 provider: T,
72 cx: &mut Context<Self>,
73 ) {
74 let id = provider.id();
75
76 let subscription = provider.subscribe(cx, |_, cx| {
77 cx.emit(Event::ProviderStateChanged);
78 });
79 if let Some(subscription) = subscription {
80 subscription.detach();
81 }
82
83 self.providers.insert(id.clone(), Arc::new(provider));
84 cx.emit(Event::AddedProvider(id));
85 }
86
87 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
88 if self.providers.remove(&id).is_some() {
89 cx.emit(Event::RemovedProvider(id));
90 }
91 }
92
93 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
94 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
95 let mut providers = Vec::with_capacity(self.providers.len());
96 if let Some(provider) = self.providers.get(&zed_provider_id) {
97 providers.push(provider.clone());
98 }
99 providers.extend(self.providers.values().filter_map(|p| {
100 if p.id() != zed_provider_id {
101 Some(p.clone())
102 } else {
103 None
104 }
105 }));
106 providers
107 }
108
109 pub fn available_models<'a>(
110 &'a self,
111 cx: &'a App,
112 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
113 self.providers
114 .values()
115 .flat_map(|provider| provider.provided_models(cx))
116 }
117
118 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
119 self.providers.get(id).cloned()
120 }
121
122 pub fn select_default_model(
123 &mut self,
124 provider: &LanguageModelProviderId,
125 model_id: &LanguageModelId,
126 cx: &mut Context<Self>,
127 ) {
128 let Some(provider) = self.provider(provider) else {
129 return;
130 };
131
132 let models = provider.provided_models(cx);
133 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
134 self.set_default_model(Some(model), cx);
135 }
136 }
137
138 pub fn select_inline_assistant_model(
139 &mut self,
140 provider: &LanguageModelProviderId,
141 model_id: &LanguageModelId,
142 cx: &mut Context<Self>,
143 ) {
144 let Some(provider) = self.provider(provider) else {
145 return;
146 };
147
148 let models = provider.provided_models(cx);
149 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
150 self.set_inline_assistant_model(Some(model), cx);
151 }
152 }
153
154 pub fn select_commit_message_model(
155 &mut self,
156 provider: &LanguageModelProviderId,
157 model_id: &LanguageModelId,
158 cx: &mut Context<Self>,
159 ) {
160 let Some(provider) = self.provider(provider) else {
161 return;
162 };
163
164 let models = provider.provided_models(cx);
165 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
166 self.set_commit_message_model(Some(model), cx);
167 }
168 }
169
170 pub fn select_thread_summary_model(
171 &mut self,
172 provider: &LanguageModelProviderId,
173 model_id: &LanguageModelId,
174 cx: &mut Context<Self>,
175 ) {
176 let Some(provider) = self.provider(provider) else {
177 return;
178 };
179
180 let models = provider.provided_models(cx);
181 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
182 self.set_thread_summary_model(Some(model), cx);
183 }
184 }
185
186 pub fn set_default_model(
187 &mut self,
188 model: Option<Arc<dyn LanguageModel>>,
189 cx: &mut Context<Self>,
190 ) {
191 if let Some(model) = model {
192 let provider_id = model.provider_id();
193 if let Some(provider) = self.providers.get(&provider_id).cloned() {
194 self.default_model = Some(ConfiguredModel { provider, model });
195 cx.emit(Event::DefaultModelChanged);
196 } else {
197 log::warn!("Active model's provider not found in registry");
198 }
199 } else {
200 self.default_model = None;
201 cx.emit(Event::DefaultModelChanged);
202 }
203 }
204
205 pub fn set_inline_assistant_model(
206 &mut self,
207 model: Option<Arc<dyn LanguageModel>>,
208 cx: &mut Context<Self>,
209 ) {
210 if let Some(model) = model {
211 let provider_id = model.provider_id();
212 if let Some(provider) = self.providers.get(&provider_id).cloned() {
213 self.inline_assistant_model = Some(ConfiguredModel { provider, model });
214 cx.emit(Event::InlineAssistantModelChanged);
215 } else {
216 log::warn!("Inline assistant model's provider not found in registry");
217 }
218 } else {
219 self.inline_assistant_model = None;
220 cx.emit(Event::InlineAssistantModelChanged);
221 }
222 }
223
224 pub fn set_commit_message_model(
225 &mut self,
226 model: Option<Arc<dyn LanguageModel>>,
227 cx: &mut Context<Self>,
228 ) {
229 if let Some(model) = model {
230 let provider_id = model.provider_id();
231 if let Some(provider) = self.providers.get(&provider_id).cloned() {
232 self.commit_message_model = Some(ConfiguredModel { provider, model });
233 cx.emit(Event::CommitMessageModelChanged);
234 } else {
235 log::warn!("Commit message model's provider not found in registry");
236 }
237 } else {
238 self.commit_message_model = None;
239 cx.emit(Event::CommitMessageModelChanged);
240 }
241 }
242
243 pub fn set_thread_summary_model(
244 &mut self,
245 model: Option<Arc<dyn LanguageModel>>,
246 cx: &mut Context<Self>,
247 ) {
248 if let Some(model) = model {
249 let provider_id = model.provider_id();
250 if let Some(provider) = self.providers.get(&provider_id).cloned() {
251 self.thread_summary_model = Some(ConfiguredModel { provider, model });
252 cx.emit(Event::ThreadSummaryModelChanged);
253 } else {
254 log::warn!("Thread summary model's provider not found in registry");
255 }
256 } else {
257 self.thread_summary_model = None;
258 cx.emit(Event::ThreadSummaryModelChanged);
259 }
260 }
261
262 pub fn default_model(&self) -> Option<ConfiguredModel> {
263 #[cfg(debug_assertions)]
264 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
265 return None;
266 }
267
268 self.default_model.clone()
269 }
270
271 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
272 self.inline_assistant_model
273 .clone()
274 .or_else(|| self.default_model())
275 }
276
277 pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
278 self.commit_message_model
279 .clone()
280 .or_else(|| self.default_model())
281 }
282
283 pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
284 self.thread_summary_model
285 .clone()
286 .or_else(|| self.default_model())
287 }
288
289 /// Selects and sets the inline alternatives for language models based on
290 /// provider name and id.
291 pub fn select_inline_alternative_models(
292 &mut self,
293 alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
294 cx: &mut Context<Self>,
295 ) {
296 let mut selected_alternatives = Vec::new();
297
298 for (provider_id, model_id) in alternatives {
299 if let Some(provider) = self.providers.get(&provider_id) {
300 if let Some(model) = provider
301 .provided_models(cx)
302 .iter()
303 .find(|m| m.id() == model_id)
304 {
305 selected_alternatives.push(model.clone());
306 }
307 }
308 }
309
310 self.inline_alternatives = selected_alternatives;
311 }
312
313 /// The models to use for inline assists. Returns the union of the active
314 /// model and all inline alternatives. When there are multiple models, the
315 /// user will be able to cycle through results.
316 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
317 &self.inline_alternatives
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering the fake provider makes it the only provider listed; unregistering it by
    /// id empties the registry again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake_id = crate::fake_provider::provider_id();

        registry.update(cx, |this, cx| this.register_provider(FakeLanguageModelProvider, cx));

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_id);

        registry.update(cx, |this, cx| this.unregister_provider(fake_id, cx));

        let remaining = registry.read(cx).providers();
        assert!(remaining.is_empty());
    }
}