completion_provider.rs

  1mod anthropic;
  2#[cfg(test)]
  3mod fake;
  4mod open_ai;
  5mod zed;
  6
  7pub use anthropic::*;
  8#[cfg(test)]
  9pub use fake::*;
 10pub use open_ai::*;
 11pub use zed::*;
 12
 13use crate::{
 14    assistant_settings::{AssistantProvider, AssistantSettings},
 15    LanguageModel, LanguageModelRequest,
 16};
 17use anyhow::Result;
 18use client::Client;
 19use futures::{future::BoxFuture, stream::BoxStream};
 20use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
 21use settings::{Settings, SettingsStore};
 22use std::sync::Arc;
 23use std::time::Duration;
 24
 25pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 26    let mut settings_version = 0;
 27    let provider = match &AssistantSettings::get_global(cx).provider {
 28        AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev(
 29            ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
 30        ),
 31        AssistantProvider::OpenAi {
 32            model,
 33            api_url,
 34            low_speed_timeout_in_seconds,
 35        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
 36            model.clone(),
 37            api_url.clone(),
 38            client.http_client(),
 39            low_speed_timeout_in_seconds.map(Duration::from_secs),
 40            settings_version,
 41        )),
 42        AssistantProvider::Anthropic {
 43            model,
 44            api_url,
 45            low_speed_timeout_in_seconds,
 46        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
 47            model.clone(),
 48            api_url.clone(),
 49            client.http_client(),
 50            low_speed_timeout_in_seconds.map(Duration::from_secs),
 51            settings_version,
 52        )),
 53    };
 54    cx.set_global(provider);
 55
 56    cx.observe_global::<SettingsStore>(move |cx| {
 57        settings_version += 1;
 58        cx.update_global::<CompletionProvider, _>(|provider, cx| {
 59            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
 60                (
 61                    CompletionProvider::OpenAi(provider),
 62                    AssistantProvider::OpenAi {
 63                        model,
 64                        api_url,
 65                        low_speed_timeout_in_seconds,
 66                    },
 67                ) => {
 68                    provider.update(
 69                        model.clone(),
 70                        api_url.clone(),
 71                        low_speed_timeout_in_seconds.map(Duration::from_secs),
 72                        settings_version,
 73                    );
 74                }
 75                (
 76                    CompletionProvider::Anthropic(provider),
 77                    AssistantProvider::Anthropic {
 78                        model,
 79                        api_url,
 80                        low_speed_timeout_in_seconds,
 81                    },
 82                ) => {
 83                    provider.update(
 84                        model.clone(),
 85                        api_url.clone(),
 86                        low_speed_timeout_in_seconds.map(Duration::from_secs),
 87                        settings_version,
 88                    );
 89                }
 90                (
 91                    CompletionProvider::ZedDotDev(provider),
 92                    AssistantProvider::ZedDotDev { model },
 93                ) => {
 94                    provider.update(model.clone(), settings_version);
 95                }
 96                (_, AssistantProvider::ZedDotDev { model }) => {
 97                    *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
 98                        model.clone(),
 99                        client.clone(),
100                        settings_version,
101                        cx,
102                    ));
103                }
104                (
105                    _,
106                    AssistantProvider::OpenAi {
107                        model,
108                        api_url,
109                        low_speed_timeout_in_seconds,
110                    },
111                ) => {
112                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
113                        model.clone(),
114                        api_url.clone(),
115                        client.http_client(),
116                        low_speed_timeout_in_seconds.map(Duration::from_secs),
117                        settings_version,
118                    ));
119                }
120                (
121                    _,
122                    AssistantProvider::Anthropic {
123                        model,
124                        api_url,
125                        low_speed_timeout_in_seconds,
126                    },
127                ) => {
128                    *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
129                        model.clone(),
130                        api_url.clone(),
131                        client.http_client(),
132                        low_speed_timeout_in_seconds.map(Duration::from_secs),
133                        settings_version,
134                    ));
135                }
136            }
137        })
138    })
139    .detach();
140}
141
142pub enum CompletionProvider {
143    OpenAi(OpenAiCompletionProvider),
144    Anthropic(AnthropicCompletionProvider),
145    ZedDotDev(ZedDotDevCompletionProvider),
146    #[cfg(test)]
147    Fake(FakeCompletionProvider),
148}
149
150impl gpui::Global for CompletionProvider {}
151
152impl CompletionProvider {
153    pub fn global(cx: &AppContext) -> &Self {
154        cx.global::<Self>()
155    }
156
157    pub fn available_models(&self) -> Vec<LanguageModel> {
158        match self {
159            CompletionProvider::OpenAi(provider) => provider
160                .available_models()
161                .map(LanguageModel::OpenAi)
162                .collect(),
163            CompletionProvider::Anthropic(provider) => provider
164                .available_models()
165                .map(LanguageModel::Anthropic)
166                .collect(),
167            CompletionProvider::ZedDotDev(provider) => provider
168                .available_models()
169                .map(LanguageModel::ZedDotDev)
170                .collect(),
171            #[cfg(test)]
172            CompletionProvider::Fake(_) => unimplemented!(),
173        }
174    }
175
176    pub fn settings_version(&self) -> usize {
177        match self {
178            CompletionProvider::OpenAi(provider) => provider.settings_version(),
179            CompletionProvider::Anthropic(provider) => provider.settings_version(),
180            CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
181            #[cfg(test)]
182            CompletionProvider::Fake(_) => unimplemented!(),
183        }
184    }
185
186    pub fn is_authenticated(&self) -> bool {
187        match self {
188            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
189            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
190            CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
191            #[cfg(test)]
192            CompletionProvider::Fake(_) => true,
193        }
194    }
195
196    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
197        match self {
198            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
199            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
200            CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
201            #[cfg(test)]
202            CompletionProvider::Fake(_) => Task::ready(Ok(())),
203        }
204    }
205
206    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
207        match self {
208            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
209            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
210            CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
211            #[cfg(test)]
212            CompletionProvider::Fake(_) => unimplemented!(),
213        }
214    }
215
216    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
217        match self {
218            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
219            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
220            CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
221            #[cfg(test)]
222            CompletionProvider::Fake(_) => Task::ready(Ok(())),
223        }
224    }
225
226    pub fn model(&self) -> LanguageModel {
227        match self {
228            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
229            CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
230            CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()),
231            #[cfg(test)]
232            CompletionProvider::Fake(_) => LanguageModel::default(),
233        }
234    }
235
236    pub fn count_tokens(
237        &self,
238        request: LanguageModelRequest,
239        cx: &AppContext,
240    ) -> BoxFuture<'static, Result<usize>> {
241        match self {
242            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
243            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
244            CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
245            #[cfg(test)]
246            CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
247        }
248    }
249
250    pub fn complete(
251        &self,
252        request: LanguageModelRequest,
253    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
254        match self {
255            CompletionProvider::OpenAi(provider) => provider.complete(request),
256            CompletionProvider::Anthropic(provider) => provider.complete(request),
257            CompletionProvider::ZedDotDev(provider) => provider.complete(request),
258            #[cfg(test)]
259            CompletionProvider::Fake(provider) => provider.complete(),
260        }
261    }
262}