completion_provider.rs

  1mod anthropic;
  2mod cloud;
  3#[cfg(test)]
  4mod fake;
  5mod open_ai;
  6
  7pub use anthropic::*;
  8pub use cloud::*;
  9#[cfg(test)]
 10pub use fake::*;
 11pub use open_ai::*;
 12
 13use crate::{
 14    assistant_settings::{AssistantProvider, AssistantSettings},
 15    LanguageModel, LanguageModelRequest,
 16};
 17use anyhow::Result;
 18use client::Client;
 19use futures::{future::BoxFuture, stream::BoxStream};
 20use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
 21use settings::{Settings, SettingsStore};
 22use std::sync::Arc;
 23use std::time::Duration;
 24
 25pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 26    let mut settings_version = 0;
 27    let provider = match &AssistantSettings::get_global(cx).provider {
 28        AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
 29            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
 30        ),
 31        AssistantProvider::OpenAi {
 32            model,
 33            api_url,
 34            low_speed_timeout_in_seconds,
 35        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
 36            model.clone(),
 37            api_url.clone(),
 38            client.http_client(),
 39            low_speed_timeout_in_seconds.map(Duration::from_secs),
 40            settings_version,
 41        )),
 42        AssistantProvider::Anthropic {
 43            model,
 44            api_url,
 45            low_speed_timeout_in_seconds,
 46        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
 47            model.clone(),
 48            api_url.clone(),
 49            client.http_client(),
 50            low_speed_timeout_in_seconds.map(Duration::from_secs),
 51            settings_version,
 52        )),
 53    };
 54    cx.set_global(provider);
 55
 56    cx.observe_global::<SettingsStore>(move |cx| {
 57        settings_version += 1;
 58        cx.update_global::<CompletionProvider, _>(|provider, cx| {
 59            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
 60                (
 61                    CompletionProvider::OpenAi(provider),
 62                    AssistantProvider::OpenAi {
 63                        model,
 64                        api_url,
 65                        low_speed_timeout_in_seconds,
 66                    },
 67                ) => {
 68                    provider.update(
 69                        model.clone(),
 70                        api_url.clone(),
 71                        low_speed_timeout_in_seconds.map(Duration::from_secs),
 72                        settings_version,
 73                    );
 74                }
 75                (
 76                    CompletionProvider::Anthropic(provider),
 77                    AssistantProvider::Anthropic {
 78                        model,
 79                        api_url,
 80                        low_speed_timeout_in_seconds,
 81                    },
 82                ) => {
 83                    provider.update(
 84                        model.clone(),
 85                        api_url.clone(),
 86                        low_speed_timeout_in_seconds.map(Duration::from_secs),
 87                        settings_version,
 88                    );
 89                }
 90                (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
 91                    provider.update(model.clone(), settings_version);
 92                }
 93                (_, AssistantProvider::ZedDotDev { model }) => {
 94                    *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
 95                        model.clone(),
 96                        client.clone(),
 97                        settings_version,
 98                        cx,
 99                    ));
100                }
101                (
102                    _,
103                    AssistantProvider::OpenAi {
104                        model,
105                        api_url,
106                        low_speed_timeout_in_seconds,
107                    },
108                ) => {
109                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
110                        model.clone(),
111                        api_url.clone(),
112                        client.http_client(),
113                        low_speed_timeout_in_seconds.map(Duration::from_secs),
114                        settings_version,
115                    ));
116                }
117                (
118                    _,
119                    AssistantProvider::Anthropic {
120                        model,
121                        api_url,
122                        low_speed_timeout_in_seconds,
123                    },
124                ) => {
125                    *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
126                        model.clone(),
127                        api_url.clone(),
128                        client.http_client(),
129                        low_speed_timeout_in_seconds.map(Duration::from_secs),
130                        settings_version,
131                    ));
132                }
133            }
134        })
135    })
136    .detach();
137}
138
/// The completion backend currently in use, one variant per supported
/// provider implementation.
///
/// A single value of this type is registered as a gpui global by `init` and
/// read back through [`CompletionProvider::global`].
pub enum CompletionProvider {
    /// Backed by [`OpenAiCompletionProvider`].
    OpenAi(OpenAiCompletionProvider),
    /// Backed by [`AnthropicCompletionProvider`].
    Anthropic(AnthropicCompletionProvider),
    /// Backed by [`CloudCompletionProvider`].
    Cloud(CloudCompletionProvider),
    /// Test-only stand-in backed by [`FakeCompletionProvider`].
    #[cfg(test)]
    Fake(FakeCompletionProvider),
}
146
// Marker impl that lets a `CompletionProvider` be stored as a gpui global
// (see `cx.set_global` in `init` and `cx.global::<Self>()` in `global`).
impl gpui::Global for CompletionProvider {}
148
149impl CompletionProvider {
150    pub fn global(cx: &AppContext) -> &Self {
151        cx.global::<Self>()
152    }
153
154    pub fn available_models(&self) -> Vec<LanguageModel> {
155        match self {
156            CompletionProvider::OpenAi(provider) => provider
157                .available_models()
158                .map(LanguageModel::OpenAi)
159                .collect(),
160            CompletionProvider::Anthropic(provider) => provider
161                .available_models()
162                .map(LanguageModel::Anthropic)
163                .collect(),
164            CompletionProvider::Cloud(provider) => provider
165                .available_models()
166                .map(LanguageModel::Cloud)
167                .collect(),
168            #[cfg(test)]
169            CompletionProvider::Fake(_) => unimplemented!(),
170        }
171    }
172
173    pub fn settings_version(&self) -> usize {
174        match self {
175            CompletionProvider::OpenAi(provider) => provider.settings_version(),
176            CompletionProvider::Anthropic(provider) => provider.settings_version(),
177            CompletionProvider::Cloud(provider) => provider.settings_version(),
178            #[cfg(test)]
179            CompletionProvider::Fake(_) => unimplemented!(),
180        }
181    }
182
183    pub fn is_authenticated(&self) -> bool {
184        match self {
185            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
186            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
187            CompletionProvider::Cloud(provider) => provider.is_authenticated(),
188            #[cfg(test)]
189            CompletionProvider::Fake(_) => true,
190        }
191    }
192
193    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
194        match self {
195            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
196            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
197            CompletionProvider::Cloud(provider) => provider.authenticate(cx),
198            #[cfg(test)]
199            CompletionProvider::Fake(_) => Task::ready(Ok(())),
200        }
201    }
202
203    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
204        match self {
205            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
206            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
207            CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
208            #[cfg(test)]
209            CompletionProvider::Fake(_) => unimplemented!(),
210        }
211    }
212
213    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
214        match self {
215            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
216            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
217            CompletionProvider::Cloud(_) => Task::ready(Ok(())),
218            #[cfg(test)]
219            CompletionProvider::Fake(_) => Task::ready(Ok(())),
220        }
221    }
222
223    pub fn model(&self) -> LanguageModel {
224        match self {
225            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
226            CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
227            CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
228            #[cfg(test)]
229            CompletionProvider::Fake(_) => LanguageModel::default(),
230        }
231    }
232
233    pub fn count_tokens(
234        &self,
235        request: LanguageModelRequest,
236        cx: &AppContext,
237    ) -> BoxFuture<'static, Result<usize>> {
238        match self {
239            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
240            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
241            CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
242            #[cfg(test)]
243            CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
244        }
245    }
246
247    pub fn complete(
248        &self,
249        request: LanguageModelRequest,
250    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
251        match self {
252            CompletionProvider::OpenAi(provider) => provider.complete(request),
253            CompletionProvider::Anthropic(provider) => provider.complete(request),
254            CompletionProvider::Cloud(provider) => provider.complete(request),
255            #[cfg(test)]
256            CompletionProvider::Fake(provider) => provider.complete(),
257        }
258    }
259}