1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{prelude::*, App, Context, Entity, EventEmitter, Global};
7use std::sync::Arc;
8
9pub fn init(cx: &mut App) {
10 let registry = cx.new(|_cx| LanguageModelRegistry::default());
11 cx.set_global(GlobalLanguageModelRegistry(registry));
12}
13
/// Newtype wrapper that stores the registry entity as a gpui global.
///
/// It is installed by `init` (and by `LanguageModelRegistry::test`). It is read
/// back by `LanguageModelRegistry::global` / `read_global`.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
17
/// Central registry of language model providers, plus the current model
/// selections (active model, editor model, inline-assist alternatives).
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Selection made via `set_active_model` / `set_active_provider`.
    /// Its `model` is `None` when only a provider was chosen.
    active_model: Option<ActiveModel>,
    /// Selection made via `set_editor_model` / `select_editor_model`.
    editor_model: Option<ActiveModel>,
    /// Registered providers, keyed by the id each reported at registration time.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Models chosen by `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
25
/// A selected provider, optionally narrowed to one of its models.
pub struct ActiveModel {
    /// The provider that owns (or will own) the selected model.
    provider: Arc<dyn LanguageModelProvider>,
    /// The selected model. `None` when only the provider has been chosen
    /// (see `LanguageModelRegistry::set_active_provider`).
    model: Option<Arc<dyn LanguageModel>>,
}
30
/// Notifications emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// The active model or active provider was set or cleared.
    ActiveModelChanged,
    /// The editor model was set or cleared.
    EditorModelChanged,
    /// A registered provider reported a state change through its subscription.
    ProviderStateChanged,
    /// A provider with this id was registered.
    AddedProvider(LanguageModelProviderId),
    /// A previously registered provider with this id was removed.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
40
41impl LanguageModelRegistry {
42 pub fn global(cx: &App) -> Entity<Self> {
43 cx.global::<GlobalLanguageModelRegistry>().0.clone()
44 }
45
46 pub fn read_global(cx: &App) -> &Self {
47 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
48 }
49
50 #[cfg(any(test, feature = "test-support"))]
51 pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
52 let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
53 let registry = cx.new(|cx| {
54 let mut registry = Self::default();
55 registry.register_provider(fake_provider.clone(), cx);
56 let model = fake_provider.provided_models(cx)[0].clone();
57 registry.set_active_model(Some(model), cx);
58 registry
59 });
60 cx.set_global(GlobalLanguageModelRegistry(registry));
61 fake_provider
62 }
63
64 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
65 &mut self,
66 provider: T,
67 cx: &mut Context<Self>,
68 ) {
69 let id = provider.id();
70
71 let subscription = provider.subscribe(cx, |_, cx| {
72 cx.emit(Event::ProviderStateChanged);
73 });
74 if let Some(subscription) = subscription {
75 subscription.detach();
76 }
77
78 self.providers.insert(id.clone(), Arc::new(provider));
79 cx.emit(Event::AddedProvider(id));
80 }
81
82 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
83 if self.providers.remove(&id).is_some() {
84 cx.emit(Event::RemovedProvider(id));
85 }
86 }
87
88 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
89 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
90 let mut providers = Vec::with_capacity(self.providers.len());
91 if let Some(provider) = self.providers.get(&zed_provider_id) {
92 providers.push(provider.clone());
93 }
94 providers.extend(self.providers.values().filter_map(|p| {
95 if p.id() != zed_provider_id {
96 Some(p.clone())
97 } else {
98 None
99 }
100 }));
101 providers
102 }
103
104 pub fn available_models<'a>(
105 &'a self,
106 cx: &'a App,
107 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
108 self.providers
109 .values()
110 .flat_map(|provider| provider.provided_models(cx))
111 }
112
113 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
114 self.providers.get(id).cloned()
115 }
116
117 pub fn select_active_model(
118 &mut self,
119 provider: &LanguageModelProviderId,
120 model_id: &LanguageModelId,
121 cx: &mut Context<Self>,
122 ) {
123 let Some(provider) = self.provider(provider) else {
124 return;
125 };
126
127 let models = provider.provided_models(cx);
128 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
129 self.set_active_model(Some(model), cx);
130 }
131 }
132
133 pub fn select_editor_model(
134 &mut self,
135 provider: &LanguageModelProviderId,
136 model_id: &LanguageModelId,
137 cx: &mut Context<Self>,
138 ) {
139 let Some(provider) = self.provider(provider) else {
140 return;
141 };
142
143 let models = provider.provided_models(cx);
144 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
145 self.set_editor_model(Some(model), cx);
146 }
147 }
148
149 pub fn set_active_provider(
150 &mut self,
151 provider: Option<Arc<dyn LanguageModelProvider>>,
152 cx: &mut Context<Self>,
153 ) {
154 self.active_model = provider.map(|provider| ActiveModel {
155 provider,
156 model: None,
157 });
158 cx.emit(Event::ActiveModelChanged);
159 }
160
161 pub fn set_active_model(
162 &mut self,
163 model: Option<Arc<dyn LanguageModel>>,
164 cx: &mut Context<Self>,
165 ) {
166 if let Some(model) = model {
167 let provider_id = model.provider_id();
168 if let Some(provider) = self.providers.get(&provider_id).cloned() {
169 self.active_model = Some(ActiveModel {
170 provider,
171 model: Some(model),
172 });
173 cx.emit(Event::ActiveModelChanged);
174 } else {
175 log::warn!("Active model's provider not found in registry");
176 }
177 } else {
178 self.active_model = None;
179 cx.emit(Event::ActiveModelChanged);
180 }
181 }
182
183 pub fn set_editor_model(
184 &mut self,
185 model: Option<Arc<dyn LanguageModel>>,
186 cx: &mut Context<Self>,
187 ) {
188 if let Some(model) = model {
189 let provider_id = model.provider_id();
190 if let Some(provider) = self.providers.get(&provider_id).cloned() {
191 self.editor_model = Some(ActiveModel {
192 provider,
193 model: Some(model),
194 });
195 cx.emit(Event::EditorModelChanged);
196 } else {
197 log::warn!("Active model's provider not found in registry");
198 }
199 } else {
200 self.editor_model = None;
201 cx.emit(Event::EditorModelChanged);
202 }
203 }
204
205 pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
206 Some(self.active_model.as_ref()?.provider.clone())
207 }
208
209 pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
210 self.active_model.as_ref()?.model.clone()
211 }
212
213 pub fn editor_model(&self) -> Option<Arc<dyn LanguageModel>> {
214 self.editor_model.as_ref()?.model.clone()
215 }
216
217 /// Selects and sets the inline alternatives for language models based on
218 /// provider name and id.
219 pub fn select_inline_alternative_models(
220 &mut self,
221 alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
222 cx: &mut Context<Self>,
223 ) {
224 let mut selected_alternatives = Vec::new();
225
226 for (provider_id, model_id) in alternatives {
227 if let Some(provider) = self.providers.get(&provider_id) {
228 if let Some(model) = provider
229 .provided_models(cx)
230 .iter()
231 .find(|m| m.id() == model_id)
232 {
233 selected_alternatives.push(model.clone());
234 }
235 }
236 }
237
238 self.inline_alternatives = selected_alternatives;
239 }
240
241 /// The models to use for inline assists. Returns the union of the active
242 /// model and all inline alternatives. When there are multiple models, the
243 /// user will be able to cycle through results.
244 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
245 &self.inline_alternatives
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::{self, FakeLanguageModelProvider};

    /// Registering the fake provider makes it the sole listed provider, and
    /// unregistering it by id empties the registry again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let expected_id = fake_provider::provider_id();

        registry.update(cx, |reg, cx| {
            reg.register_provider(FakeLanguageModelProvider, cx);
        });

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), expected_id);

        registry.update(cx, |reg, cx| {
            reg.unregister_provider(expected_id.clone(), cx);
        });

        let remaining = registry.read(cx).providers();
        assert!(remaining.is_empty());
    }
}