1use std::sync::Arc;
2
3use ::settings::{Settings, SettingsStore};
4use client::{Client, UserStore};
5use collections::HashSet;
6use gpui::{App, Context, Entity};
7use language_model::{LanguageModelProviderId, LanguageModelRegistry};
8use provider::deepseek::DeepSeekLanguageModelProvider;
9
10mod api_key;
11pub mod provider;
12mod settings;
13pub mod ui;
14
15use crate::provider::anthropic::AnthropicLanguageModelProvider;
16use crate::provider::bedrock::BedrockLanguageModelProvider;
17use crate::provider::cloud::CloudLanguageModelProvider;
18use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
19use crate::provider::google::GoogleLanguageModelProvider;
20use crate::provider::lmstudio::LmStudioLanguageModelProvider;
21pub use crate::provider::mistral::MistralLanguageModelProvider;
22use crate::provider::ollama::OllamaLanguageModelProvider;
23use crate::provider::open_ai::OpenAiLanguageModelProvider;
24use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
25use crate::provider::open_router::OpenRouterLanguageModelProvider;
26use crate::provider::vercel::VercelLanguageModelProvider;
27use crate::provider::x_ai::XAiLanguageModelProvider;
28pub use crate::settings::*;
29
30pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
31 let registry = LanguageModelRegistry::global(cx);
32 registry.update(cx, |registry, cx| {
33 register_language_model_providers(registry, user_store, client.clone(), cx);
34 });
35
36 let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
37 .openai_compatible
38 .keys()
39 .cloned()
40 .collect::<HashSet<_>>();
41
42 registry.update(cx, |registry, cx| {
43 register_openai_compatible_providers(
44 registry,
45 &HashSet::default(),
46 &openai_compatible_providers,
47 client.clone(),
48 cx,
49 );
50 });
51 cx.observe_global::<SettingsStore>(move |cx| {
52 let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
53 .openai_compatible
54 .keys()
55 .cloned()
56 .collect::<HashSet<_>>();
57 if openai_compatible_providers_new != openai_compatible_providers {
58 registry.update(cx, |registry, cx| {
59 register_openai_compatible_providers(
60 registry,
61 &openai_compatible_providers,
62 &openai_compatible_providers_new,
63 client.clone(),
64 cx,
65 );
66 });
67 openai_compatible_providers = openai_compatible_providers_new;
68 }
69 })
70 .detach();
71}
72
73fn register_openai_compatible_providers(
74 registry: &mut LanguageModelRegistry,
75 old: &HashSet<Arc<str>>,
76 new: &HashSet<Arc<str>>,
77 client: Arc<Client>,
78 cx: &mut Context<LanguageModelRegistry>,
79) {
80 for provider_id in old {
81 if !new.contains(provider_id) {
82 registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
83 }
84 }
85
86 for provider_id in new {
87 if !old.contains(provider_id) {
88 registry.register_provider(
89 Arc::new(OpenAiCompatibleLanguageModelProvider::new(
90 provider_id.clone(),
91 client.http_client(),
92 cx,
93 )),
94 cx,
95 );
96 }
97 }
98}
99
100fn register_language_model_providers(
101 registry: &mut LanguageModelRegistry,
102 user_store: Entity<UserStore>,
103 client: Arc<Client>,
104 cx: &mut Context<LanguageModelRegistry>,
105) {
106 registry.register_provider(
107 Arc::new(CloudLanguageModelProvider::new(
108 user_store,
109 client.clone(),
110 cx,
111 )),
112 cx,
113 );
114 registry.register_provider(
115 Arc::new(AnthropicLanguageModelProvider::new(
116 client.http_client(),
117 cx,
118 )),
119 cx,
120 );
121 registry.register_provider(
122 Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
123 cx,
124 );
125 registry.register_provider(
126 Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
127 cx,
128 );
129 registry.register_provider(
130 Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
131 cx,
132 );
133 registry.register_provider(
134 Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
135 cx,
136 );
137 registry.register_provider(
138 Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
139 cx,
140 );
141 registry.register_provider(
142 MistralLanguageModelProvider::global(client.http_client(), cx),
143 cx,
144 );
145 registry.register_provider(
146 Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
147 cx,
148 );
149 registry.register_provider(
150 Arc::new(OpenRouterLanguageModelProvider::new(
151 client.http_client(),
152 cx,
153 )),
154 cx,
155 );
156 registry.register_provider(
157 Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
158 cx,
159 );
160 registry.register_provider(
161 Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
162 cx,
163 );
164 registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
165}