1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{prelude::*, App, Context, Entity, EventEmitter, Global};
7use std::sync::Arc;
8
9pub fn init(cx: &mut App) {
10 let registry = cx.new(|_cx| LanguageModelRegistry::default());
11 cx.set_global(GlobalLanguageModelRegistry(registry));
12}
13
14struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
15
16impl Global for GlobalLanguageModelRegistry {}
17
18#[derive(Default)]
19pub struct LanguageModelRegistry {
20 active_model: Option<ActiveModel>,
21 providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
22 inline_alternatives: Vec<Arc<dyn LanguageModel>>,
23}
24
25pub struct ActiveModel {
26 provider: Arc<dyn LanguageModelProvider>,
27 model: Option<Arc<dyn LanguageModel>>,
28}
29
30pub enum Event {
31 ActiveModelChanged,
32 ProviderStateChanged,
33 AddedProvider(LanguageModelProviderId),
34 RemovedProvider(LanguageModelProviderId),
35}
36
37impl EventEmitter<Event> for LanguageModelRegistry {}
38
39impl LanguageModelRegistry {
40 pub fn global(cx: &App) -> Entity<Self> {
41 cx.global::<GlobalLanguageModelRegistry>().0.clone()
42 }
43
44 pub fn read_global(cx: &App) -> &Self {
45 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
46 }
47
48 #[cfg(any(test, feature = "test-support"))]
49 pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
50 let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
51 let registry = cx.new(|cx| {
52 let mut registry = Self::default();
53 registry.register_provider(fake_provider.clone(), cx);
54 let model = fake_provider.provided_models(cx)[0].clone();
55 registry.set_active_model(Some(model), cx);
56 registry
57 });
58 cx.set_global(GlobalLanguageModelRegistry(registry));
59 fake_provider
60 }
61
62 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
63 &mut self,
64 provider: T,
65 cx: &mut Context<Self>,
66 ) {
67 let id = provider.id();
68
69 let subscription = provider.subscribe(cx, |_, cx| {
70 cx.emit(Event::ProviderStateChanged);
71 });
72 if let Some(subscription) = subscription {
73 subscription.detach();
74 }
75
76 self.providers.insert(id.clone(), Arc::new(provider));
77 cx.emit(Event::AddedProvider(id));
78 }
79
80 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
81 if self.providers.remove(&id).is_some() {
82 cx.emit(Event::RemovedProvider(id));
83 }
84 }
85
86 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
87 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
88 let mut providers = Vec::with_capacity(self.providers.len());
89 if let Some(provider) = self.providers.get(&zed_provider_id) {
90 providers.push(provider.clone());
91 }
92 providers.extend(self.providers.values().filter_map(|p| {
93 if p.id() != zed_provider_id {
94 Some(p.clone())
95 } else {
96 None
97 }
98 }));
99 providers
100 }
101
102 pub fn available_models<'a>(
103 &'a self,
104 cx: &'a App,
105 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
106 self.providers
107 .values()
108 .flat_map(|provider| provider.provided_models(cx))
109 }
110
111 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
112 self.providers.get(id).cloned()
113 }
114
115 pub fn select_active_model(
116 &mut self,
117 provider: &LanguageModelProviderId,
118 model_id: &LanguageModelId,
119 cx: &mut Context<Self>,
120 ) {
121 let Some(provider) = self.provider(provider) else {
122 return;
123 };
124
125 let models = provider.provided_models(cx);
126 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
127 self.set_active_model(Some(model), cx);
128 }
129 }
130
131 pub fn set_active_provider(
132 &mut self,
133 provider: Option<Arc<dyn LanguageModelProvider>>,
134 cx: &mut Context<Self>,
135 ) {
136 self.active_model = provider.map(|provider| ActiveModel {
137 provider,
138 model: None,
139 });
140 cx.emit(Event::ActiveModelChanged);
141 }
142
143 pub fn set_active_model(
144 &mut self,
145 model: Option<Arc<dyn LanguageModel>>,
146 cx: &mut Context<Self>,
147 ) {
148 if let Some(model) = model {
149 let provider_id = model.provider_id();
150 if let Some(provider) = self.providers.get(&provider_id).cloned() {
151 self.active_model = Some(ActiveModel {
152 provider,
153 model: Some(model),
154 });
155 cx.emit(Event::ActiveModelChanged);
156 } else {
157 log::warn!("Active model's provider not found in registry");
158 }
159 } else {
160 self.active_model = None;
161 cx.emit(Event::ActiveModelChanged);
162 }
163 }
164
165 pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
166 Some(self.active_model.as_ref()?.provider.clone())
167 }
168
169 pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
170 self.active_model.as_ref()?.model.clone()
171 }
172
173 /// Selects and sets the inline alternatives for language models based on
174 /// provider name and id.
175 pub fn select_inline_alternative_models(
176 &mut self,
177 alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
178 cx: &mut Context<Self>,
179 ) {
180 let mut selected_alternatives = Vec::new();
181
182 for (provider_id, model_id) in alternatives {
183 if let Some(provider) = self.providers.get(&provider_id) {
184 if let Some(model) = provider
185 .provided_models(cx)
186 .iter()
187 .find(|m| m.id() == model_id)
188 {
189 selected_alternatives.push(model.clone());
190 }
191 }
192 }
193
194 self.inline_alternatives = selected_alternatives;
195 }
196
197 /// The models to use for inline assists. Returns the union of the active
198 /// model and all inline alternatives. When there are multiple models, the
199 /// user will be able to cycle through results.
200 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
201 &self.inline_alternatives
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering the fake provider makes it the only listed provider.
    /// Unregistering it by id empties the registry again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider, cx)
        });

        let fake_id = crate::fake_provider::provider_id();
        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_id);

        registry.update(cx, |registry, cx| registry.unregister_provider(fake_id, cx));

        assert!(registry.read(cx).providers().is_empty());
    }
}