completion_provider.rs

mod anthropic;
mod cloud;
#[cfg(test)]
mod fake;
mod ollama;
mod open_ai;

pub use anthropic::*;
pub use cloud::*;
#[cfg(test)]
pub use fake::*;
pub use ollama::*;
pub use open_ai::*;

use crate::{
    assistant_settings::{AssistantProvider, AssistantSettings},
    LanguageModel, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use std::time::Duration;

/// Chooses which model to use for the OpenAI provider.
/// If the requested model is unavailable, falls back to the first available model,
/// or to the requested model itself when no models are listed as available.
fn choose_openai_model(
    model: &::open_ai::Model,
    available_models: &[::open_ai::Model],
) -> ::open_ai::Model {
    available_models
        .iter()
        .find(|&m| m == model)
        .or_else(|| available_models.first())
        .unwrap_or_else(|| model)
        .clone()
}

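/// Builds the global [`CompletionProvider`] from the current [`AssistantSettings`]
/// and keeps it in sync as those settings change.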
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    let mut settings_version = 0;
    let provider = match &AssistantSettings::get_global(cx).provider {
        AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
        ),
        AssistantProvider::OpenAi {
            model,
            api_url,
            low_speed_timeout_in_seconds,
            available_models,
        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
            choose_openai_model(model, available_models),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        )),
        AssistantProvider::Anthropic {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        )),
        AssistantProvider::Ollama {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
            cx,
        )),
    };
    cx.set_global(provider);

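    // When the assistant settings change, update the existing provider in place if its
    // kind still matches the selected backend; otherwise replace it with a freshly
    // constructed provider.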
    cx.observe_global::<SettingsStore>(move |cx| {
        settings_version += 1;
        cx.update_global::<CompletionProvider, _>(|provider, cx| {
            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
                (
                    CompletionProvider::OpenAi(provider),
                    AssistantProvider::OpenAi {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                        available_models,
                    },
                ) => {
                    provider.update(
                        choose_openai_model(model, available_models),
                        api_url.clone(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    );
                }
                (
                    CompletionProvider::Anthropic(provider),
                    AssistantProvider::Anthropic {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    provider.update(
                        model.clone(),
                        api_url.clone(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    );
                }
                (
                    CompletionProvider::Ollama(provider),
                    AssistantProvider::Ollama {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    provider.update(
                        model.clone(),
                        api_url.clone(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                        cx,
                    );
                }
                (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
                    provider.update(model.clone(), settings_version);
                }
                (_, AssistantProvider::ZedDotDev { model }) => {
                    *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
                        model.clone(),
                        client.clone(),
                        settings_version,
                        cx,
                    ));
                }
                (
                    _,
                    AssistantProvider::OpenAi {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                        available_models,
                    },
                ) => {
                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
                        choose_openai_model(model, available_models),
                        api_url.clone(),
                        client.http_client(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    ));
                }
                (
                    _,
                    AssistantProvider::Anthropic {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
                        model.clone(),
                        api_url.clone(),
                        client.http_client(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    ));
                }
                (
                    _,
                    AssistantProvider::Ollama {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
                        model.clone(),
                        api_url.clone(),
                        client.http_client(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                        cx,
                    ));
                }
            }
        })
    })
    .detach();
}

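/// The active completion backend, stored as a GPUI global.
///
/// A rough sketch of a typical call site (illustrative only; the exact wiring
/// depends on the caller):
///
/// ```ignore
/// let provider = CompletionProvider::global(cx);
/// let stream_future = provider.complete(request);
/// ```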
pub enum CompletionProvider {
    OpenAi(OpenAiCompletionProvider),
    Anthropic(AnthropicCompletionProvider),
    Cloud(CloudCompletionProvider),
    #[cfg(test)]
    Fake(FakeCompletionProvider),
    Ollama(OllamaCompletionProvider),
}

impl gpui::Global for CompletionProvider {}

impl CompletionProvider {
    pub fn global(cx: &AppContext) -> &Self {
        cx.global::<Self>()
    }

    pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
        match self {
            CompletionProvider::OpenAi(provider) => provider
                .available_models(cx)
                .map(LanguageModel::OpenAi)
                .collect(),
            CompletionProvider::Anthropic(provider) => provider
                .available_models()
                .map(LanguageModel::Anthropic)
                .collect(),
            CompletionProvider::Cloud(provider) => provider
                .available_models()
                .map(LanguageModel::Cloud)
                .collect(),
            CompletionProvider::Ollama(provider) => provider
                .available_models()
                .map(|model| LanguageModel::Ollama(model.clone()))
                .collect(),
            #[cfg(test)]
            CompletionProvider::Fake(_) => unimplemented!(),
        }
    }

    pub fn settings_version(&self) -> usize {
        match self {
            CompletionProvider::OpenAi(provider) => provider.settings_version(),
            CompletionProvider::Anthropic(provider) => provider.settings_version(),
            CompletionProvider::Cloud(provider) => provider.settings_version(),
            CompletionProvider::Ollama(provider) => provider.settings_version(),
            #[cfg(test)]
            CompletionProvider::Fake(_) => unimplemented!(),
        }
    }

    pub fn is_authenticated(&self) -> bool {
        match self {
            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
            CompletionProvider::Cloud(provider) => provider.is_authenticated(),
            CompletionProvider::Ollama(provider) => provider.is_authenticated(),
            #[cfg(test)]
            CompletionProvider::Fake(_) => true,
        }
    }

    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
            CompletionProvider::Cloud(provider) => provider.authenticate(cx),
            CompletionProvider::Ollama(provider) => provider.authenticate(cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => Task::ready(Ok(())),
        }
    }

    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        match self {
            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
            CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
            CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => unimplemented!(),
        }
    }

    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
            CompletionProvider::Cloud(_) => Task::ready(Ok(())),
            CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => Task::ready(Ok(())),
        }
    }

    pub fn model(&self) -> LanguageModel {
        match self {
            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
            CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
            CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
            CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
            #[cfg(test)]
            CompletionProvider::Fake(_) => LanguageModel::default(),
        }
    }

    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
            CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
            CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
        }
    }

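    /// Streams the completion for `request` from the active backend, yielding
    /// text chunks as they arrive.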
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.complete(request),
            CompletionProvider::Anthropic(provider) => provider.complete(request),
            CompletionProvider::Cloud(provider) => provider.complete(request),
            CompletionProvider::Ollama(provider) => provider.complete(request),
            #[cfg(test)]
            CompletionProvider::Fake(provider) => provider.complete(),
        }
    }
}