1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::sync::Arc;
8
9pub fn init(cx: &mut App) {
10 let registry = cx.new(|_cx| LanguageModelRegistry::default());
11 cx.set_global(GlobalLanguageModelRegistry(registry));
12}
13
/// Newtype that stores the registry entity as a gpui global.
///
/// Installed by `init` and, in tests, by `LanguageModelRegistry::test`; read back by
/// `LanguageModelRegistry::global` and `LanguageModelRegistry::read_global`.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

// Marks the wrapper as usable with `cx.set_global` / `cx.global`.
impl Global for GlobalLanguageModelRegistry {}
17
/// Registry of language model providers, together with the currently selected active model,
/// editor model, and inline-assist alternative models.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Set by `set_active_model` / `select_active_model`, or by `set_active_provider`, which
    /// records a provider with no model chosen.
    active_model: Option<ActiveModel>,
    /// Set by `set_editor_model` / `select_editor_model`.
    editor_model: Option<ActiveModel>,
    /// Registered providers keyed by id; being a `BTreeMap`, iteration follows id order.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Models chosen via `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
25
/// A registered provider paired with an optional model from that provider.
///
/// `model` is `None` when only the provider was selected (see
/// `LanguageModelRegistry::set_active_provider`).
pub struct ActiveModel {
    provider: Arc<dyn LanguageModelProvider>,
    model: Option<Arc<dyn LanguageModel>>,
}
30
/// Events emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// Emitted by `set_active_provider` and by `set_active_model` whenever the active
    /// selection is replaced or cleared.
    ActiveModelChanged,
    /// Emitted by `set_editor_model` whenever the editor selection is replaced or cleared.
    EditorModelChanged,
    /// Forwarded from a registered provider's state subscription (set up in
    /// `register_provider`) each time that subscription fires.
    ProviderStateChanged,
    /// Emitted by `register_provider` with the id of the provider that was inserted.
    AddedProvider(LanguageModelProviderId),
    /// Emitted by `unregister_provider`, only when a provider with this id was registered.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
40
41impl LanguageModelRegistry {
42 pub fn global(cx: &App) -> Entity<Self> {
43 cx.global::<GlobalLanguageModelRegistry>().0.clone()
44 }
45
46 pub fn read_global(cx: &App) -> &Self {
47 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
48 }
49
50 #[cfg(any(test, feature = "test-support"))]
51 pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
52 let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
53 let registry = cx.new(|cx| {
54 let mut registry = Self::default();
55 registry.register_provider(fake_provider.clone(), cx);
56 let model = fake_provider.provided_models(cx)[0].clone();
57 registry.set_active_model(Some(model), cx);
58 registry
59 });
60 cx.set_global(GlobalLanguageModelRegistry(registry));
61 fake_provider
62 }
63
64 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
65 &mut self,
66 provider: T,
67 cx: &mut Context<Self>,
68 ) {
69 let id = provider.id();
70
71 let subscription = provider.subscribe(cx, |_, cx| {
72 cx.emit(Event::ProviderStateChanged);
73 });
74 if let Some(subscription) = subscription {
75 subscription.detach();
76 }
77
78 self.providers.insert(id.clone(), Arc::new(provider));
79 cx.emit(Event::AddedProvider(id));
80 }
81
82 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
83 if self.providers.remove(&id).is_some() {
84 cx.emit(Event::RemovedProvider(id));
85 }
86 }
87
88 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
89 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
90 let mut providers = Vec::with_capacity(self.providers.len());
91 if let Some(provider) = self.providers.get(&zed_provider_id) {
92 providers.push(provider.clone());
93 }
94 providers.extend(self.providers.values().filter_map(|p| {
95 if p.id() != zed_provider_id {
96 Some(p.clone())
97 } else {
98 None
99 }
100 }));
101 providers
102 }
103
104 pub fn available_models<'a>(
105 &'a self,
106 cx: &'a App,
107 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
108 self.providers
109 .values()
110 .flat_map(|provider| provider.provided_models(cx))
111 }
112
113 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
114 self.providers.get(id).cloned()
115 }
116
117 pub fn select_active_model(
118 &mut self,
119 provider: &LanguageModelProviderId,
120 model_id: &LanguageModelId,
121 cx: &mut Context<Self>,
122 ) {
123 let Some(provider) = self.provider(provider) else {
124 return;
125 };
126
127 let models = provider.provided_models(cx);
128 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
129 self.set_active_model(Some(model), cx);
130 }
131 }
132
133 pub fn select_editor_model(
134 &mut self,
135 provider: &LanguageModelProviderId,
136 model_id: &LanguageModelId,
137 cx: &mut Context<Self>,
138 ) {
139 let Some(provider) = self.provider(provider) else {
140 return;
141 };
142
143 let models = provider.provided_models(cx);
144 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
145 self.set_editor_model(Some(model), cx);
146 }
147 }
148
149 pub fn set_active_provider(
150 &mut self,
151 provider: Option<Arc<dyn LanguageModelProvider>>,
152 cx: &mut Context<Self>,
153 ) {
154 self.active_model = provider.map(|provider| ActiveModel {
155 provider,
156 model: None,
157 });
158 cx.emit(Event::ActiveModelChanged);
159 }
160
161 pub fn set_active_model(
162 &mut self,
163 model: Option<Arc<dyn LanguageModel>>,
164 cx: &mut Context<Self>,
165 ) {
166 if let Some(model) = model {
167 let provider_id = model.provider_id();
168 if let Some(provider) = self.providers.get(&provider_id).cloned() {
169 self.active_model = Some(ActiveModel {
170 provider,
171 model: Some(model),
172 });
173 cx.emit(Event::ActiveModelChanged);
174 } else {
175 log::warn!("Active model's provider not found in registry");
176 }
177 } else {
178 self.active_model = None;
179 cx.emit(Event::ActiveModelChanged);
180 }
181 }
182
183 pub fn set_editor_model(
184 &mut self,
185 model: Option<Arc<dyn LanguageModel>>,
186 cx: &mut Context<Self>,
187 ) {
188 if let Some(model) = model {
189 let provider_id = model.provider_id();
190 if let Some(provider) = self.providers.get(&provider_id).cloned() {
191 self.editor_model = Some(ActiveModel {
192 provider,
193 model: Some(model),
194 });
195 cx.emit(Event::EditorModelChanged);
196 } else {
197 log::warn!("Active model's provider not found in registry");
198 }
199 } else {
200 self.editor_model = None;
201 cx.emit(Event::EditorModelChanged);
202 }
203 }
204
205 pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
206 #[cfg(debug_assertions)]
207 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
208 return None;
209 }
210
211 Some(self.active_model.as_ref()?.provider.clone())
212 }
213
214 pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
215 self.active_model.as_ref()?.model.clone()
216 }
217
218 pub fn editor_model(&self) -> Option<Arc<dyn LanguageModel>> {
219 self.editor_model.as_ref()?.model.clone()
220 }
221
222 /// Selects and sets the inline alternatives for language models based on
223 /// provider name and id.
224 pub fn select_inline_alternative_models(
225 &mut self,
226 alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
227 cx: &mut Context<Self>,
228 ) {
229 let mut selected_alternatives = Vec::new();
230
231 for (provider_id, model_id) in alternatives {
232 if let Some(provider) = self.providers.get(&provider_id) {
233 if let Some(model) = provider
234 .provided_models(cx)
235 .iter()
236 .find(|m| m.id() == model_id)
237 {
238 selected_alternatives.push(model.clone());
239 }
240 }
241 }
242
243 self.inline_alternatives = selected_alternatives;
244 }
245
246 /// The models to use for inline assists. Returns the union of the active
247 /// model and all inline alternatives. When there are multiple models, the
248 /// user will be able to cycle through results.
249 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
250 &self.inline_alternatives
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::{self, FakeLanguageModelProvider};

    /// Registering the fake provider lists it as the sole provider; unregistering it by id
    /// empties the registry again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let fake_id = fake_provider::provider_id();
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider, cx)
        });
        let listed = registry.read(cx).providers();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id(), fake_id);

        registry.update(cx, |registry, cx| {
            registry.unregister_provider(fake_id.clone(), cx)
        });
        assert!(registry.read(cx).providers().is_empty());
    }
}