mod anthropic;
mod cloud;
#[cfg(test)]
mod fake;
mod ollama;
mod open_ai;

pub use anthropic::*;
pub use cloud::*;
#[cfg(test)]
pub use fake::*;
pub use ollama::*;
pub use open_ai::*;

use crate::{
    assistant_settings::{AssistantProvider, AssistantSettings},
    LanguageModel, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use std::time::Duration;

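/// Initializes the global [`CompletionProvider`] from the current
/// `AssistantSettings` and keeps it in sync with later settings changes.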
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    let mut settings_version = 0;
    let provider = match &AssistantSettings::get_global(cx).provider {
        AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
        ),
        AssistantProvider::OpenAi {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        )),
        AssistantProvider::Anthropic {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        )),
        AssistantProvider::Ollama {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        )),
    };
    cx.set_global(provider);

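    // Swap or update the global provider whenever the assistant settings change.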
    cx.observe_global::<SettingsStore>(move |cx| {
        settings_version += 1;
        cx.update_global::<CompletionProvider, _>(|provider, cx| {
            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
                (
                    CompletionProvider::OpenAi(provider),
                    AssistantProvider::OpenAi {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    provider.update(
                        model.clone(),
                        api_url.clone(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    );
                }
                (
                    CompletionProvider::Anthropic(provider),
                    AssistantProvider::Anthropic {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    provider.update(
                        model.clone(),
                        api_url.clone(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    );
                }
                (
                    CompletionProvider::Ollama(provider),
                    AssistantProvider::Ollama {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    provider.update(
                        model.clone(),
                        api_url.clone(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    );
                }
                (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
                    provider.update(model.clone(), settings_version);
                }
                (_, AssistantProvider::ZedDotDev { model }) => {
                    *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
                        model.clone(),
                        client.clone(),
                        settings_version,
                        cx,
                    ));
                }
                (
                    _,
                    AssistantProvider::OpenAi {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
                        model.clone(),
                        api_url.clone(),
                        client.http_client(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    ));
                }
                (
                    _,
                    AssistantProvider::Anthropic {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
                        model.clone(),
                        api_url.clone(),
                        client.http_client(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    ));
                }
                (
                    _,
                    AssistantProvider::Ollama {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
                        model.clone(),
                        api_url.clone(),
                        client.http_client(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    ));
                }
            }
        })
    })
    .detach();
}

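/// The active completion backend, selected by the `AssistantProvider` setting
/// and stored as a GPUI global.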
pub enum CompletionProvider {
    OpenAi(OpenAiCompletionProvider),
    Anthropic(AnthropicCompletionProvider),
    Cloud(CloudCompletionProvider),
    #[cfg(test)]
    Fake(FakeCompletionProvider),
    Ollama(OllamaCompletionProvider),
}

impl gpui::Global for CompletionProvider {}

impl CompletionProvider {
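    /// Returns a reference to the global provider.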
    pub fn global(cx: &AppContext) -> &Self {
        cx.global::<Self>()
    }

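    /// Returns the models offered by the active provider as [`LanguageModel`]s.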
    pub fn available_models(&self) -> Vec<LanguageModel> {
        match self {
            CompletionProvider::OpenAi(provider) => provider
                .available_models()
                .map(LanguageModel::OpenAi)
                .collect(),
            CompletionProvider::Anthropic(provider) => provider
                .available_models()
                .map(LanguageModel::Anthropic)
                .collect(),
            CompletionProvider::Cloud(provider) => provider
                .available_models()
                .map(LanguageModel::Cloud)
                .collect(),
            CompletionProvider::Ollama(provider) => provider
                .available_models()
                .map(|model| LanguageModel::Ollama(model.clone()))
                .collect(),
            #[cfg(test)]
            CompletionProvider::Fake(_) => unimplemented!(),
        }
    }

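    /// Returns the settings version the active provider was last configured with.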
    pub fn settings_version(&self) -> usize {
        match self {
            CompletionProvider::OpenAi(provider) => provider.settings_version(),
            CompletionProvider::Anthropic(provider) => provider.settings_version(),
            CompletionProvider::Cloud(provider) => provider.settings_version(),
            CompletionProvider::Ollama(provider) => provider.settings_version(),
            #[cfg(test)]
            CompletionProvider::Fake(_) => unimplemented!(),
        }
    }

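    /// Reports whether the active provider is authenticated.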
    pub fn is_authenticated(&self) -> bool {
        match self {
            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
            CompletionProvider::Cloud(provider) => provider.is_authenticated(),
            CompletionProvider::Ollama(provider) => provider.is_authenticated(),
            #[cfg(test)]
            CompletionProvider::Fake(_) => true,
        }
    }

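    /// Attempts to authenticate with the active provider.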
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
            CompletionProvider::Cloud(provider) => provider.authenticate(cx),
            CompletionProvider::Ollama(provider) => provider.authenticate(cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => Task::ready(Ok(())),
        }
    }

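    /// Builds the view that prompts the user to authenticate the active provider.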
    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        match self {
            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
            CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
            CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => unimplemented!(),
        }
    }

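    /// Clears any stored credentials for the active provider; a no-op for the
    /// Cloud provider.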
    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
            CompletionProvider::Cloud(_) => Task::ready(Ok(())),
            CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => Task::ready(Ok(())),
        }
    }

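    /// Returns the model the active provider is currently configured to use.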
    pub fn model(&self) -> LanguageModel {
        match self {
            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
            CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
            CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
            CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
            #[cfg(test)]
            CompletionProvider::Fake(_) => LanguageModel::default(),
        }
    }

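    /// Counts the tokens in `request` for the active provider's model.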
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
            CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
            CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
        }
    }

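    /// Requests a completion for `request`, streamed as chunks of text.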
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.complete(request),
            CompletionProvider::Anthropic(provider) => provider.complete(request),
            CompletionProvider::Cloud(provider) => provider.complete(request),
            CompletionProvider::Ollama(provider) => provider.complete(request),
            #[cfg(test)]
            CompletionProvider::Fake(provider) => provider.complete(),
        }
    }
}