1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{prelude::*, AppContext, EventEmitter, Global, Model, ModelContext};
7use std::sync::Arc;
8
9pub fn init(cx: &mut AppContext) {
10 let registry = cx.new_model(|_cx| LanguageModelRegistry::default());
11 cx.set_global(GlobalLanguageModelRegistry(registry));
12}
13
/// Newtype that lets the shared [`LanguageModelRegistry`] model be stored as a
/// gpui global. It is installed by [`init`] (and by `LanguageModelRegistry::test`)
/// and read back by `LanguageModelRegistry::global` / `read_global`.
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
17
/// Registry of language model providers, together with the currently active
/// provider/model selection and the models offered as inline-assist
/// alternatives.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// The current selection (a provider plus, optionally, one of its
    /// models); `None` when nothing is selected.
    active_model: Option<ActiveModel>,
    /// Registered providers keyed by id. Being a `BTreeMap`, iteration visits
    /// them in ascending id order.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Additional models for inline assists, as chosen by
    /// `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}

/// The active selection: a provider plus, optionally, one of its models.
pub struct ActiveModel {
    /// Provider of the active selection.
    provider: Arc<dyn LanguageModelProvider>,
    /// Selected model; `None` when only a provider was chosen via
    /// `LanguageModelRegistry::set_active_provider`.
    model: Option<Arc<dyn LanguageModel>>,
}
29
/// Notifications emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// The active provider and/or model was set or cleared.
    ActiveModelChanged,
    /// A registered provider's subscription reported a state change.
    ProviderStateChanged,
    /// A provider was registered under the given id.
    AddedProvider(LanguageModelProviderId),
    /// The provider with the given id was unregistered.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
38
39impl LanguageModelRegistry {
40 pub fn global(cx: &AppContext) -> Model<Self> {
41 cx.global::<GlobalLanguageModelRegistry>().0.clone()
42 }
43
44 pub fn read_global(cx: &AppContext) -> &Self {
45 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
46 }
47
48 #[cfg(any(test, feature = "test-support"))]
49 pub fn test(cx: &mut AppContext) -> crate::fake_provider::FakeLanguageModelProvider {
50 let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
51 let registry = cx.new_model(|cx| {
52 let mut registry = Self::default();
53 registry.register_provider(fake_provider.clone(), cx);
54 let model = fake_provider.provided_models(cx)[0].clone();
55 registry.set_active_model(Some(model), cx);
56 registry
57 });
58 cx.set_global(GlobalLanguageModelRegistry(registry));
59 fake_provider
60 }
61
62 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
63 &mut self,
64 provider: T,
65 cx: &mut ModelContext<Self>,
66 ) {
67 let id = provider.id();
68
69 let subscription = provider.subscribe(cx, |_, cx| {
70 cx.emit(Event::ProviderStateChanged);
71 });
72 if let Some(subscription) = subscription {
73 subscription.detach();
74 }
75
76 self.providers.insert(id.clone(), Arc::new(provider));
77 cx.emit(Event::AddedProvider(id));
78 }
79
80 pub fn unregister_provider(
81 &mut self,
82 id: LanguageModelProviderId,
83 cx: &mut ModelContext<Self>,
84 ) {
85 if self.providers.remove(&id).is_some() {
86 cx.emit(Event::RemovedProvider(id));
87 }
88 }
89
90 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
91 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
92 let mut providers = Vec::with_capacity(self.providers.len());
93 if let Some(provider) = self.providers.get(&zed_provider_id) {
94 providers.push(provider.clone());
95 }
96 providers.extend(self.providers.values().filter_map(|p| {
97 if p.id() != zed_provider_id {
98 Some(p.clone())
99 } else {
100 None
101 }
102 }));
103 providers
104 }
105
106 pub fn available_models<'a>(
107 &'a self,
108 cx: &'a AppContext,
109 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
110 self.providers
111 .values()
112 .flat_map(|provider| provider.provided_models(cx))
113 }
114
115 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
116 self.providers.get(id).cloned()
117 }
118
119 pub fn select_active_model(
120 &mut self,
121 provider: &LanguageModelProviderId,
122 model_id: &LanguageModelId,
123 cx: &mut ModelContext<Self>,
124 ) {
125 let Some(provider) = self.provider(provider) else {
126 return;
127 };
128
129 let models = provider.provided_models(cx);
130 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
131 self.set_active_model(Some(model), cx);
132 }
133 }
134
135 pub fn set_active_provider(
136 &mut self,
137 provider: Option<Arc<dyn LanguageModelProvider>>,
138 cx: &mut ModelContext<Self>,
139 ) {
140 self.active_model = provider.map(|provider| ActiveModel {
141 provider,
142 model: None,
143 });
144 cx.emit(Event::ActiveModelChanged);
145 }
146
147 pub fn set_active_model(
148 &mut self,
149 model: Option<Arc<dyn LanguageModel>>,
150 cx: &mut ModelContext<Self>,
151 ) {
152 if let Some(model) = model {
153 let provider_id = model.provider_id();
154 if let Some(provider) = self.providers.get(&provider_id).cloned() {
155 self.active_model = Some(ActiveModel {
156 provider,
157 model: Some(model),
158 });
159 cx.emit(Event::ActiveModelChanged);
160 } else {
161 log::warn!("Active model's provider not found in registry");
162 }
163 } else {
164 self.active_model = None;
165 cx.emit(Event::ActiveModelChanged);
166 }
167 }
168
169 pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
170 Some(self.active_model.as_ref()?.provider.clone())
171 }
172
173 pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
174 self.active_model.as_ref()?.model.clone()
175 }
176
177 /// Selects and sets the inline alternatives for language models based on
178 /// provider name and id.
179 pub fn select_inline_alternative_models(
180 &mut self,
181 alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
182 cx: &mut ModelContext<Self>,
183 ) {
184 let mut selected_alternatives = Vec::new();
185
186 for (provider_id, model_id) in alternatives {
187 if let Some(provider) = self.providers.get(&provider_id) {
188 if let Some(model) = provider
189 .provided_models(cx)
190 .iter()
191 .find(|m| m.id() == model_id)
192 {
193 selected_alternatives.push(model.clone());
194 }
195 }
196 }
197
198 self.inline_alternatives = selected_alternatives;
199 }
200
201 /// The models to use for inline assists. Returns the union of the active
202 /// model and all inline alternatives. When there are multiple models, the
203 /// user will be able to cycle through results.
204 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
205 &self.inline_alternatives
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::{provider_id, FakeLanguageModelProvider};

    /// Registering the fake provider should make it the only listed provider,
    /// and unregistering it by id should leave the registry empty again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| this.register_provider(FakeLanguageModelProvider, cx));
        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), provider_id());

        registry.update(cx, |this, cx| this.unregister_provider(provider_id(), cx));
        assert!(registry.read(cx).providers().is_empty());
    }
}