completion_provider.rs

mod anthropic;
mod cloud;
#[cfg(test)]
mod fake;
mod ollama;
mod open_ai;

pub use anthropic::*;
pub use cloud::*;
#[cfg(test)]
pub use fake::*;
pub use ollama::*;
pub use open_ai::*;

use crate::{
    assistant_settings::{AssistantProvider, AssistantSettings},
    LanguageModel, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use std::time::Duration;

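/// Installs a global `CompletionProvider` built from the current `AssistantSettings`
/// and keeps it in sync as those settings change.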
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    let mut settings_version = 0;
    let provider = match &AssistantSettings::get_global(cx).provider {
        AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
        ),
        AssistantProvider::OpenAi {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        )),
        AssistantProvider::Anthropic {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        )),
        AssistantProvider::Ollama {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
            cx,
        )),
    };
    cx.set_global(provider);

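    // When settings change, update the existing provider in place if its kind is
    // unchanged; otherwise replace it with a freshly constructed provider.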
    cx.observe_global::<SettingsStore>(move |cx| {
        settings_version += 1;
        cx.update_global::<CompletionProvider, _>(|provider, cx| {
            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
                (
                    CompletionProvider::OpenAi(provider),
                    AssistantProvider::OpenAi {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    provider.update(
                        model.clone(),
                        api_url.clone(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    );
                }
                (
                    CompletionProvider::Anthropic(provider),
                    AssistantProvider::Anthropic {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    provider.update(
                        model.clone(),
                        api_url.clone(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    );
                }

                (
                    CompletionProvider::Ollama(provider),
                    AssistantProvider::Ollama {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    provider.update(
                        model.clone(),
                        api_url.clone(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                        cx,
                    );
                }

                (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
                    provider.update(model.clone(), settings_version);
                }
                (_, AssistantProvider::ZedDotDev { model }) => {
                    *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
                        model.clone(),
                        client.clone(),
                        settings_version,
                        cx,
                    ));
                }
                (
                    _,
                    AssistantProvider::OpenAi {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
                        model.clone(),
                        api_url.clone(),
                        client.http_client(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    ));
                }
                (
                    _,
                    AssistantProvider::Anthropic {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
                        model.clone(),
                        api_url.clone(),
                        client.http_client(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                    ));
                }
                (
                    _,
                    AssistantProvider::Ollama {
                        model,
                        api_url,
                        low_speed_timeout_in_seconds,
                    },
                ) => {
                    *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
                        model.clone(),
                        api_url.clone(),
                        client.http_client(),
                        low_speed_timeout_in_seconds.map(Duration::from_secs),
                        settings_version,
                        cx,
                    ));
                }
            }
        })
    })
    .detach();
}

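/// The active completion backend, stored as a GPUI global; every method below
/// dispatches to the wrapped provider.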
pub enum CompletionProvider {
    OpenAi(OpenAiCompletionProvider),
    Anthropic(AnthropicCompletionProvider),
    Cloud(CloudCompletionProvider),
    #[cfg(test)]
    Fake(FakeCompletionProvider),
    Ollama(OllamaCompletionProvider),
}

impl gpui::Global for CompletionProvider {}

impl CompletionProvider {
    pub fn global(cx: &AppContext) -> &Self {
        cx.global::<Self>()
    }

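    /// Lists the models offered by the active provider, wrapped as `LanguageModel`s.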
    pub fn available_models(&self) -> Vec<LanguageModel> {
        match self {
            CompletionProvider::OpenAi(provider) => provider
                .available_models()
                .map(LanguageModel::OpenAi)
                .collect(),
            CompletionProvider::Anthropic(provider) => provider
                .available_models()
                .map(LanguageModel::Anthropic)
                .collect(),
            CompletionProvider::Cloud(provider) => provider
                .available_models()
                .map(LanguageModel::Cloud)
                .collect(),
            CompletionProvider::Ollama(provider) => provider
                .available_models()
                .map(|model| LanguageModel::Ollama(model.clone()))
                .collect(),
            #[cfg(test)]
            CompletionProvider::Fake(_) => unimplemented!(),
        }
    }

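    /// The `AssistantSettings` version this provider was last configured with.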
    pub fn settings_version(&self) -> usize {
        match self {
            CompletionProvider::OpenAi(provider) => provider.settings_version(),
            CompletionProvider::Anthropic(provider) => provider.settings_version(),
            CompletionProvider::Cloud(provider) => provider.settings_version(),
            CompletionProvider::Ollama(provider) => provider.settings_version(),
            #[cfg(test)]
            CompletionProvider::Fake(_) => unimplemented!(),
        }
    }

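    /// Whether the active provider is ready to serve requests (the fake provider always is).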
    pub fn is_authenticated(&self) -> bool {
        match self {
            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
            CompletionProvider::Cloud(provider) => provider.is_authenticated(),
            CompletionProvider::Ollama(provider) => provider.is_authenticated(),
            #[cfg(test)]
            CompletionProvider::Fake(_) => true,
        }
    }

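    /// Starts provider-specific authentication; the fake provider resolves immediately.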
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
            CompletionProvider::Cloud(provider) => provider.authenticate(cx),
            CompletionProvider::Ollama(provider) => provider.authenticate(cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => Task::ready(Ok(())),
        }
    }

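    /// Returns a view prompting the user to configure credentials for the active provider.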
    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        match self {
            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
            CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
            CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => unimplemented!(),
        }
    }

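    /// Clears any stored credentials; the Cloud and fake providers have nothing to reset.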
    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
            CompletionProvider::Cloud(_) => Task::ready(Ok(())),
            CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => Task::ready(Ok(())),
        }
    }

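    /// The currently selected model, wrapped as a `LanguageModel`.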
    pub fn model(&self) -> LanguageModel {
        match self {
            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
            CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
            CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
            CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
            #[cfg(test)]
            CompletionProvider::Fake(_) => LanguageModel::default(),
        }
    }

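    /// Counts (or estimates) the tokens a request would consume; the fake provider reports zero.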
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
            CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
            CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
            #[cfg(test)]
            CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
        }
    }

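    /// Streams the completion for a request as chunks of text.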
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match self {
            CompletionProvider::OpenAi(provider) => provider.complete(request),
            CompletionProvider::Anthropic(provider) => provider.complete(request),
            CompletionProvider::Cloud(provider) => provider.complete(request),
            CompletionProvider::Ollama(provider) => provider.complete(request),
            #[cfg(test)]
            CompletionProvider::Fake(provider) => provider.complete(),
        }
    }
}