1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::sync::Arc;
8
9pub fn init(cx: &mut App) {
10 let registry = cx.new(|_cx| LanguageModelRegistry::default());
11 cx.set_global(GlobalLanguageModelRegistry(registry));
12}
13
14struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);
15
16impl Global for GlobalLanguageModelRegistry {}
17
/// Registry of language model providers and of the models selected for each
/// feature (default, inline assistant, commit messages, thread summaries).
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Model returned by `default_model()` and used as the fallback for the
    /// feature-specific accessors below.
    default_model: Option<ConfiguredModel>,
    /// Optional override for the inline assistant; `None` falls back to the
    /// default model.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Optional override for commit message generation; `None` falls back to
    /// the default model.
    commit_message_model: Option<ConfiguredModel>,
    /// Optional override for thread summaries; `None` falls back to the
    /// default model.
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers keyed by provider id (a `BTreeMap`, so iteration
    /// is in id order).
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Additional models for inline assists, resolved by
    /// `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
27
/// A model identified by provider id and model id, before it has been resolved
/// against the registered providers.
pub struct SelectedModel {
    /// Id of the provider expected to offer the model.
    pub provider: LanguageModelProviderId,
    /// Id of the model within that provider.
    pub model: LanguageModelId,
}
32
/// A model resolved against a registered provider, kept together with that
/// provider.
#[derive(Clone)]
pub struct ConfiguredModel {
    /// The provider that offers `model`.
    pub provider: Arc<dyn LanguageModelProvider>,
    /// The resolved model.
    pub model: Arc<dyn LanguageModel>,
}
38
39impl ConfiguredModel {
40 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
41 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
42 }
43}
44
/// Notifications emitted by [`LanguageModelRegistry`] when its state changes.
pub enum Event {
    /// The default model was set, cleared, or replaced by a different model.
    DefaultModelChanged,
    /// The inline assistant override was set, cleared, or replaced.
    InlineAssistantModelChanged,
    /// The commit message override was set, cleared, or replaced.
    CommitMessageModelChanged,
    /// The thread summary override was set, cleared, or replaced.
    ThreadSummaryModelChanged,
    /// A registered provider reported a state change via its subscription.
    ProviderStateChanged,
    /// A provider was registered under this id. This is emitted on every
    /// registration, including one that replaces an existing provider.
    AddedProvider(LanguageModelProviderId),
    /// A provider that was registered under this id has been removed.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
56
57impl LanguageModelRegistry {
58 pub fn global(cx: &App) -> Entity<Self> {
59 cx.global::<GlobalLanguageModelRegistry>().0.clone()
60 }
61
62 pub fn read_global(cx: &App) -> &Self {
63 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
64 }
65
66 #[cfg(any(test, feature = "test-support"))]
67 pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
68 let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
69 let registry = cx.new(|cx| {
70 let mut registry = Self::default();
71 registry.register_provider(fake_provider.clone(), cx);
72 let model = fake_provider.provided_models(cx)[0].clone();
73 let configured_model = ConfiguredModel {
74 provider: Arc::new(fake_provider.clone()),
75 model,
76 };
77 registry.set_default_model(Some(configured_model), cx);
78 registry
79 });
80 cx.set_global(GlobalLanguageModelRegistry(registry));
81 fake_provider
82 }
83
84 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
85 &mut self,
86 provider: T,
87 cx: &mut Context<Self>,
88 ) {
89 let id = provider.id();
90
91 let subscription = provider.subscribe(cx, |_, cx| {
92 cx.emit(Event::ProviderStateChanged);
93 });
94 if let Some(subscription) = subscription {
95 subscription.detach();
96 }
97
98 self.providers.insert(id.clone(), Arc::new(provider));
99 cx.emit(Event::AddedProvider(id));
100 }
101
102 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
103 if self.providers.remove(&id).is_some() {
104 cx.emit(Event::RemovedProvider(id));
105 }
106 }
107
108 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
109 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
110 let mut providers = Vec::with_capacity(self.providers.len());
111 if let Some(provider) = self.providers.get(&zed_provider_id) {
112 providers.push(provider.clone());
113 }
114 providers.extend(self.providers.values().filter_map(|p| {
115 if p.id() != zed_provider_id {
116 Some(p.clone())
117 } else {
118 None
119 }
120 }));
121 providers
122 }
123
124 pub fn available_models<'a>(
125 &'a self,
126 cx: &'a App,
127 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
128 self.providers
129 .values()
130 .flat_map(|provider| provider.provided_models(cx))
131 }
132
133 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
134 self.providers.get(id).cloned()
135 }
136
137 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
138 let configured_model = model.and_then(|model| self.select_model(model, cx));
139 self.set_default_model(configured_model, cx);
140 }
141
142 pub fn select_inline_assistant_model(
143 &mut self,
144 model: Option<&SelectedModel>,
145 cx: &mut Context<Self>,
146 ) {
147 let configured_model = model.and_then(|model| self.select_model(model, cx));
148 self.set_inline_assistant_model(configured_model, cx);
149 }
150
151 pub fn select_commit_message_model(
152 &mut self,
153 model: Option<&SelectedModel>,
154 cx: &mut Context<Self>,
155 ) {
156 let configured_model = model.and_then(|model| self.select_model(model, cx));
157 self.set_commit_message_model(configured_model, cx);
158 }
159
160 pub fn select_thread_summary_model(
161 &mut self,
162 model: Option<&SelectedModel>,
163 cx: &mut Context<Self>,
164 ) {
165 let configured_model = model.and_then(|model| self.select_model(model, cx));
166 self.set_thread_summary_model(configured_model, cx);
167 }
168
169 /// Selects and sets the inline alternatives for language models based on
170 /// provider name and id.
171 pub fn select_inline_alternative_models(
172 &mut self,
173 alternatives: impl IntoIterator<Item = SelectedModel>,
174 cx: &mut Context<Self>,
175 ) {
176 self.inline_alternatives = alternatives
177 .into_iter()
178 .flat_map(|alternative| {
179 self.select_model(&alternative, cx)
180 .map(|configured_model| configured_model.model)
181 })
182 .collect::<Vec<_>>();
183 }
184
185 fn select_model(
186 &mut self,
187 selected_model: &SelectedModel,
188 cx: &mut Context<Self>,
189 ) -> Option<ConfiguredModel> {
190 let provider = self.provider(&selected_model.provider)?;
191 let model = provider
192 .provided_models(cx)
193 .iter()
194 .find(|model| model.id() == selected_model.model)?
195 .clone();
196 Some(ConfiguredModel { provider, model })
197 }
198
199 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
200 match (self.default_model.as_ref(), model.as_ref()) {
201 (Some(old), Some(new)) if old.is_same_as(new) => {}
202 (None, None) => {}
203 _ => cx.emit(Event::DefaultModelChanged),
204 }
205 self.default_model = model;
206 }
207
208 pub fn set_inline_assistant_model(
209 &mut self,
210 model: Option<ConfiguredModel>,
211 cx: &mut Context<Self>,
212 ) {
213 match (self.inline_assistant_model.as_ref(), model.as_ref()) {
214 (Some(old), Some(new)) if old.is_same_as(new) => {}
215 (None, None) => {}
216 _ => cx.emit(Event::InlineAssistantModelChanged),
217 }
218 self.inline_assistant_model = model;
219 }
220
221 pub fn set_commit_message_model(
222 &mut self,
223 model: Option<ConfiguredModel>,
224 cx: &mut Context<Self>,
225 ) {
226 match (self.commit_message_model.as_ref(), model.as_ref()) {
227 (Some(old), Some(new)) if old.is_same_as(new) => {}
228 (None, None) => {}
229 _ => cx.emit(Event::CommitMessageModelChanged),
230 }
231 self.commit_message_model = model;
232 }
233
234 pub fn set_thread_summary_model(
235 &mut self,
236 model: Option<ConfiguredModel>,
237 cx: &mut Context<Self>,
238 ) {
239 match (self.thread_summary_model.as_ref(), model.as_ref()) {
240 (Some(old), Some(new)) if old.is_same_as(new) => {}
241 (None, None) => {}
242 _ => cx.emit(Event::ThreadSummaryModelChanged),
243 }
244 self.thread_summary_model = model;
245 }
246
247 pub fn default_model(&self) -> Option<ConfiguredModel> {
248 #[cfg(debug_assertions)]
249 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
250 return None;
251 }
252
253 self.default_model.clone()
254 }
255
256 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
257 self.inline_assistant_model
258 .clone()
259 .or_else(|| self.default_model())
260 }
261
262 pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
263 self.commit_message_model
264 .clone()
265 .or_else(|| self.default_model())
266 }
267
268 pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
269 self.thread_summary_model
270 .clone()
271 .or_else(|| self.default_model())
272 }
273
274 /// The models to use for inline assists. Returns the union of the active
275 /// model and all inline alternatives. When there are multiple models, the
276 /// user will be able to cycle through results.
277 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
278 &self.inline_alternatives
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::{self, FakeLanguageModelProvider};

    /// Registering the fake provider makes it visible through `providers()`,
    /// and unregistering it by id removes it again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        let fake_id = fake_provider::provider_id();

        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider, cx)
        });

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), fake_id);

        registry.update(cx, |this, cx| {
            this.unregister_provider(fake_id.clone(), cx)
        });

        assert!(registry.read(cx).providers().is_empty());
    }
}