1use crate::{
2 LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
3 LanguageModelProviderState,
4};
5use collections::BTreeMap;
6use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
7use std::sync::Arc;
8use util::maybe;
9
10pub fn init(cx: &mut App) {
11 let registry = cx.new(|_cx| LanguageModelRegistry::default());
12 cx.set_global(GlobalLanguageModelRegistry(registry));
13}
14
/// Newtype around the registry entity so it can be stored with `set_global`
/// and looked up again with `global::<GlobalLanguageModelRegistry>()`.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

// Marker impl that makes the wrapper usable as a gpui global.
impl Global for GlobalLanguageModelRegistry {}
18
/// Holds the registered language model providers and the models currently
/// chosen for each purpose (default, inline assistant, commit messages,
/// thread summaries).
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Model set through `set_default_model`; the purpose-specific getters
    /// fall back to it when their own model is unset.
    default_model: Option<ConfiguredModel>,
    /// Recomputed on every `set_default_model` call from the new default
    /// model's provider (`default_fast_model`); `None` when there is no
    /// default model or the provider offers no fast model.
    default_fast_model: Option<ConfiguredModel>,
    /// Model explicitly chosen for the inline assistant, if any.
    inline_assistant_model: Option<ConfiguredModel>,
    /// Model explicitly chosen for commit messages, if any.
    commit_message_model: Option<ConfiguredModel>,
    /// Model explicitly chosen for thread summaries, if any.
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers, keyed by the id each provider reported at
    /// registration time.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Models resolved by `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
29
/// Identifies a model by provider id and model id, without holding the
/// provider or model objects themselves. `LanguageModelRegistry::select_model`
/// resolves it into a [`ConfiguredModel`].
pub struct SelectedModel {
    /// Id of the provider expected to offer the model.
    pub provider: LanguageModelProviderId,
    /// Id of the model within that provider.
    pub model: LanguageModelId,
}
34
/// A resolved model together with the provider it came from.
///
/// Cloning only clones the two `Arc` handles.
#[derive(Clone)]
pub struct ConfiguredModel {
    /// Provider that offers `model`.
    pub provider: Arc<dyn LanguageModelProvider>,
    /// The model itself.
    pub model: Arc<dyn LanguageModel>,
}
40
41impl ConfiguredModel {
42 pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
43 self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
44 }
45
46 pub fn is_provided_by_zed(&self) -> bool {
47 self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
48 }
49}
50
/// Events emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// Emitted by `set_default_model` when the default model differs from the
    /// previous one (by provider id and model id) or is set/cleared.
    DefaultModelChanged,
    /// Emitted by `set_inline_assistant_model` under the same conditions.
    InlineAssistantModelChanged,
    /// Emitted by `set_commit_message_model` under the same conditions.
    CommitMessageModelChanged,
    /// Emitted by `set_thread_summary_model` under the same conditions.
    ThreadSummaryModelChanged,
    /// Emitted from the subscription callback installed in
    /// `register_provider`, i.e. whenever a registered provider notifies.
    ProviderStateChanged,
    /// Emitted after a provider with the given id was registered.
    AddedProvider(LanguageModelProviderId),
    /// Emitted after a provider with the given id was removed.
    RemovedProvider(LanguageModelProviderId),
}

// Allows the registry entity to `emit` the events above.
impl EventEmitter<Event> for LanguageModelRegistry {}
62
63impl LanguageModelRegistry {
64 pub fn global(cx: &App) -> Entity<Self> {
65 cx.global::<GlobalLanguageModelRegistry>().0.clone()
66 }
67
68 pub fn read_global(cx: &App) -> &Self {
69 cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
70 }
71
/// Test helper: installs a global registry containing only the fake
/// provider, with that provider's first model as the default model, and
/// returns the fake provider.
///
/// Indexes the fake provider's model list at `0`, so it panics if that list
/// is empty.
#[cfg(any(test, feature = "test-support"))]
pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
    let provider = crate::fake_provider::FakeLanguageModelProvider;
    let entity = cx.new(|cx| {
        let mut this = Self::default();
        this.register_provider(provider.clone(), cx);

        let first_model = provider.provided_models(cx)[0].clone();
        let default = ConfiguredModel {
            provider: Arc::new(provider.clone()),
            model: first_model,
        };
        this.set_default_model(Some(default), cx);
        this
    });
    cx.set_global(GlobalLanguageModelRegistry(entity));
    provider
}
89
90 pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
91 &mut self,
92 provider: T,
93 cx: &mut Context<Self>,
94 ) {
95 let id = provider.id();
96
97 let subscription = provider.subscribe(cx, |_, cx| {
98 cx.emit(Event::ProviderStateChanged);
99 });
100 if let Some(subscription) = subscription {
101 subscription.detach();
102 }
103
104 self.providers.insert(id.clone(), Arc::new(provider));
105 cx.emit(Event::AddedProvider(id));
106 }
107
108 pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
109 if self.providers.remove(&id).is_some() {
110 cx.emit(Event::RemovedProvider(id));
111 }
112 }
113
114 pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
115 let zed_provider_id = LanguageModelProviderId("zed.dev".into());
116 let mut providers = Vec::with_capacity(self.providers.len());
117 if let Some(provider) = self.providers.get(&zed_provider_id) {
118 providers.push(provider.clone());
119 }
120 providers.extend(self.providers.values().filter_map(|p| {
121 if p.id() != zed_provider_id {
122 Some(p.clone())
123 } else {
124 None
125 }
126 }));
127 providers
128 }
129
130 pub fn available_models<'a>(
131 &'a self,
132 cx: &'a App,
133 ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
134 self.providers
135 .values()
136 .flat_map(|provider| provider.provided_models(cx))
137 }
138
139 pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
140 self.providers.get(id).cloned()
141 }
142
143 pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
144 let configured_model = model.and_then(|model| self.select_model(model, cx));
145 self.set_default_model(configured_model, cx);
146 }
147
148 pub fn select_inline_assistant_model(
149 &mut self,
150 model: Option<&SelectedModel>,
151 cx: &mut Context<Self>,
152 ) {
153 let configured_model = model.and_then(|model| self.select_model(model, cx));
154 self.set_inline_assistant_model(configured_model, cx);
155 }
156
157 pub fn select_commit_message_model(
158 &mut self,
159 model: Option<&SelectedModel>,
160 cx: &mut Context<Self>,
161 ) {
162 let configured_model = model.and_then(|model| self.select_model(model, cx));
163 self.set_commit_message_model(configured_model, cx);
164 }
165
166 pub fn select_thread_summary_model(
167 &mut self,
168 model: Option<&SelectedModel>,
169 cx: &mut Context<Self>,
170 ) {
171 let configured_model = model.and_then(|model| self.select_model(model, cx));
172 self.set_thread_summary_model(configured_model, cx);
173 }
174
175 /// Selects and sets the inline alternatives for language models based on
176 /// provider name and id.
177 pub fn select_inline_alternative_models(
178 &mut self,
179 alternatives: impl IntoIterator<Item = SelectedModel>,
180 cx: &mut Context<Self>,
181 ) {
182 self.inline_alternatives = alternatives
183 .into_iter()
184 .flat_map(|alternative| {
185 self.select_model(&alternative, cx)
186 .map(|configured_model| configured_model.model)
187 })
188 .collect::<Vec<_>>();
189 }
190
191 pub fn select_model(
192 &mut self,
193 selected_model: &SelectedModel,
194 cx: &mut Context<Self>,
195 ) -> Option<ConfiguredModel> {
196 let provider = self.provider(&selected_model.provider)?;
197 let model = provider
198 .provided_models(cx)
199 .iter()
200 .find(|model| model.id() == selected_model.model)?
201 .clone();
202 Some(ConfiguredModel { provider, model })
203 }
204
205 pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
206 match (self.default_model.as_ref(), model.as_ref()) {
207 (Some(old), Some(new)) if old.is_same_as(new) => {}
208 (None, None) => {}
209 _ => cx.emit(Event::DefaultModelChanged),
210 }
211 self.default_fast_model = maybe!({
212 let provider = &model.as_ref()?.provider;
213 let fast_model = provider.default_fast_model(cx)?;
214 Some(ConfiguredModel {
215 provider: provider.clone(),
216 model: fast_model,
217 })
218 });
219 self.default_model = model;
220 }
221
222 pub fn set_inline_assistant_model(
223 &mut self,
224 model: Option<ConfiguredModel>,
225 cx: &mut Context<Self>,
226 ) {
227 match (self.inline_assistant_model.as_ref(), model.as_ref()) {
228 (Some(old), Some(new)) if old.is_same_as(new) => {}
229 (None, None) => {}
230 _ => cx.emit(Event::InlineAssistantModelChanged),
231 }
232 self.inline_assistant_model = model;
233 }
234
235 pub fn set_commit_message_model(
236 &mut self,
237 model: Option<ConfiguredModel>,
238 cx: &mut Context<Self>,
239 ) {
240 match (self.commit_message_model.as_ref(), model.as_ref()) {
241 (Some(old), Some(new)) if old.is_same_as(new) => {}
242 (None, None) => {}
243 _ => cx.emit(Event::CommitMessageModelChanged),
244 }
245 self.commit_message_model = model;
246 }
247
248 pub fn set_thread_summary_model(
249 &mut self,
250 model: Option<ConfiguredModel>,
251 cx: &mut Context<Self>,
252 ) {
253 match (self.thread_summary_model.as_ref(), model.as_ref()) {
254 (Some(old), Some(new)) if old.is_same_as(new) => {}
255 (None, None) => {}
256 _ => cx.emit(Event::ThreadSummaryModelChanged),
257 }
258 self.thread_summary_model = model;
259 }
260
261 pub fn default_model(&self) -> Option<ConfiguredModel> {
262 #[cfg(debug_assertions)]
263 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
264 return None;
265 }
266
267 self.default_model.clone()
268 }
269
270 pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
271 #[cfg(debug_assertions)]
272 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
273 return None;
274 }
275
276 self.inline_assistant_model
277 .clone()
278 .or_else(|| self.default_model.clone())
279 }
280
281 pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
282 #[cfg(debug_assertions)]
283 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
284 return None;
285 }
286
287 self.commit_message_model
288 .clone()
289 .or_else(|| self.default_model.clone())
290 }
291
292 pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
293 #[cfg(debug_assertions)]
294 if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
295 return None;
296 }
297
298 self.thread_summary_model
299 .clone()
300 .or_else(|| self.default_fast_model.clone())
301 .or_else(|| self.default_model.clone())
302 }
303
304 /// The models to use for inline assists. Returns the union of the active
305 /// model and all inline alternatives. When there are multiple models, the
306 /// user will be able to cycle through results.
307 pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
308 &self.inline_alternatives
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering the fake provider makes it the only listed provider;
    /// unregistering it leaves the registry empty again.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());

        registry.update(cx, |this, cx| {
            this.register_provider(FakeLanguageModelProvider, cx);
        });

        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), crate::fake_provider::provider_id());

        registry.update(cx, |this, cx| {
            this.unregister_provider(crate::fake_provider::provider_id(), cx);
        });

        let remaining = registry.read(cx).providers();
        assert!(remaining.is_empty());
    }
}