1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
7use std::sync::Arc;
8use ui::Context;
9
10pub fn init(cx: &mut AppContext) {
11 let registry = cx.new_model(|_cx| LanguageModelRegistry::default());
12 cx.set_global(GlobalLanguageModelRegistry(registry));
13}
14
15struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
16
17impl Global for GlobalLanguageModelRegistry {}
18
19#[derive(Default)]
20pub struct LanguageModelRegistry {
21 active_model: Option<ActiveModel>,
22 providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
23 inline_alternatives: Vec<Arc<dyn LanguageModel>>,
24}
25
26pub struct ActiveModel {
27 provider: Arc<dyn LanguageModelProvider>,
28 model: Option<Arc<dyn LanguageModel>>,
29}
30
31pub enum Event {
32 ActiveModelChanged,
33 ProviderStateChanged,
34 AddedProvider(LanguageModelProviderId),
35 RemovedProvider(LanguageModelProviderId),
36}
37
38impl EventEmitter<Event> for LanguageModelRegistry {}
39
40impl LanguageModelRegistry {
41 pub fn global(cx: &AppContext) -> Model<Self> {
42 cx.global::<GlobalLanguageModelRegistry>().0.clone()
43 }
44
45 pub fn read_global(cx: &AppContext) -> &Self {
46 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
47 }
48
49 #[cfg(any(test, feature = "test-support"))]
50 pub fn test(cx: &mut AppContext) -> crate::fake_provider::FakeLanguageModelProvider {
51 let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
52 let registry = cx.new_model(|cx| {
53 let mut registry = Self::default();
54 registry.register_provider(fake_provider.clone(), cx);
55 let model = fake_provider.provided_models(cx)[0].clone();
56 registry.set_active_model(Some(model), cx);
57 registry
58 });
59 cx.set_global(GlobalLanguageModelRegistry(registry));
60 fake_provider
61 }
62
63 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
64 &mut self,
65 provider: T,
66 cx: &mut ModelContext<Self>,
67 ) {
68 let id = provider.id();
69
70 let subscription = provider.subscribe(cx, |_, cx| {
71 cx.emit(Event::ProviderStateChanged);
72 });
73 if let Some(subscription) = subscription {
74 subscription.detach();
75 }
76
77 self.providers.insert(id.clone(), Arc::new(provider));
78 cx.emit(Event::AddedProvider(id));
79 }
80
81 pub fn unregister_provider(
82 &mut self,
83 id: LanguageModelProviderId,
84 cx: &mut ModelContext<Self>,
85 ) {
86 if self.providers.remove(&id).is_some() {
87 cx.emit(Event::RemovedProvider(id));
88 }
89 }
90
91 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
92 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
93 let mut providers = Vec::with_capacity(self.providers.len());
94 if let Some(provider) = self.providers.get(&zed_provider_id) {
95 providers.push(provider.clone());
96 }
97 providers.extend(self.providers.values().filter_map(|p| {
98 if p.id() != zed_provider_id {
99 Some(p.clone())
100 } else {
101 None
102 }
103 }));
104 providers
105 }
106
107 pub fn available_models<'a>(
108 &'a self,
109 cx: &'a AppContext,
110 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
111 self.providers
112 .values()
113 .flat_map(|provider| provider.provided_models(cx))
114 }
115
116 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
117 self.providers.get(id).cloned()
118 }
119
120 pub fn select_active_model(
121 &mut self,
122 provider: &LanguageModelProviderId,
123 model_id: &LanguageModelId,
124 cx: &mut ModelContext<Self>,
125 ) {
126 let Some(provider) = self.provider(provider) else {
127 return;
128 };
129
130 let models = provider.provided_models(cx);
131 if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
132 self.set_active_model(Some(model), cx);
133 }
134 }
135
136 pub fn set_active_provider(
137 &mut self,
138 provider: Option<Arc<dyn LanguageModelProvider>>,
139 cx: &mut ModelContext<Self>,
140 ) {
141 self.active_model = provider.map(|provider| ActiveModel {
142 provider,
143 model: None,
144 });
145 cx.emit(Event::ActiveModelChanged);
146 }
147
148 pub fn set_active_model(
149 &mut self,
150 model: Option<Arc<dyn LanguageModel>>,
151 cx: &mut ModelContext<Self>,
152 ) {
153 if let Some(model) = model {
154 let provider_id = model.provider_id();
155 if let Some(provider) = self.providers.get(&provider_id).cloned() {
156 self.active_model = Some(ActiveModel {
157 provider,
158 model: Some(model),
159 });
160 cx.emit(Event::ActiveModelChanged);
161 } else {
162 log::warn!("Active model's provider not found in registry");
163 }
164 } else {
165 self.active_model = None;
166 cx.emit(Event::ActiveModelChanged);
167 }
168 }
169
170 pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
171 Some(self.active_model.as_ref()?.provider.clone())
172 }
173
174 pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
175 self.active_model.as_ref()?.model.clone()
176 }
177
178 /// Selects and sets the inline alternatives for language models based on
179 /// provider name and id.
180 pub fn select_inline_alternative_models(
181 &mut self,
182 alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
183 cx: &mut ModelContext<Self>,
184 ) {
185 let mut selected_alternatives = Vec::new();
186
187 for (provider_id, model_id) in alternatives {
188 if let Some(provider) = self.providers.get(&provider_id) {
189 if let Some(model) = provider
190 .provided_models(cx)
191 .iter()
192 .find(|m| m.id() == model_id)
193 {
194 selected_alternatives.push(model.clone());
195 }
196 }
197 }
198
199 self.inline_alternatives = selected_alternatives;
200 }
201
202 /// The models to use for inline assists. Returns the union of the active
203 /// model and all inline alternatives. When there are multiple models, the
204 /// user will be able to cycle through results.
205 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
206 &self.inline_alternatives
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::{provider_id, FakeLanguageModelProvider};

    /// Registering the fake provider makes it the only listed provider, and
    /// unregistering it by id leaves the registry empty again.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider, cx);
        });

        let listed = registry.read(cx).providers();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id(), provider_id());

        registry.update(cx, |this, cx| {
            this.unregister_provider(provider_id(), cx);
        });

        let listed = registry.read(cx).providers();
        assert!(listed.is_empty());
    }
}