1mod anthropic;
2#[cfg(test)]
3mod fake;
4mod open_ai;
5mod zed;
6
7pub use anthropic::*;
8#[cfg(test)]
9pub use fake::*;
10pub use open_ai::*;
11pub use zed::*;
12
13use crate::{
14 assistant_settings::{AssistantProvider, AssistantSettings},
15 LanguageModel, LanguageModelRequest,
16};
17use anyhow::Result;
18use client::Client;
19use futures::{future::BoxFuture, stream::BoxStream};
20use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
21use settings::{Settings, SettingsStore};
22use std::sync::Arc;
23use std::time::Duration;
24
25pub fn init(client: Arc<Client>, cx: &mut AppContext) {
26 let mut settings_version = 0;
27 let provider = match &AssistantSettings::get_global(cx).provider {
28 AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev(
29 ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
30 ),
31 AssistantProvider::OpenAi {
32 model,
33 api_url,
34 low_speed_timeout_in_seconds,
35 } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
36 model.clone(),
37 api_url.clone(),
38 client.http_client(),
39 low_speed_timeout_in_seconds.map(Duration::from_secs),
40 settings_version,
41 )),
42 AssistantProvider::Anthropic {
43 model,
44 api_url,
45 low_speed_timeout_in_seconds,
46 } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
47 model.clone(),
48 api_url.clone(),
49 client.http_client(),
50 low_speed_timeout_in_seconds.map(Duration::from_secs),
51 settings_version,
52 )),
53 };
54 cx.set_global(provider);
55
56 cx.observe_global::<SettingsStore>(move |cx| {
57 settings_version += 1;
58 cx.update_global::<CompletionProvider, _>(|provider, cx| {
59 match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
60 (
61 CompletionProvider::OpenAi(provider),
62 AssistantProvider::OpenAi {
63 model,
64 api_url,
65 low_speed_timeout_in_seconds,
66 },
67 ) => {
68 provider.update(
69 model.clone(),
70 api_url.clone(),
71 low_speed_timeout_in_seconds.map(Duration::from_secs),
72 settings_version,
73 );
74 }
75 (
76 CompletionProvider::Anthropic(provider),
77 AssistantProvider::Anthropic {
78 model,
79 api_url,
80 low_speed_timeout_in_seconds,
81 },
82 ) => {
83 provider.update(
84 model.clone(),
85 api_url.clone(),
86 low_speed_timeout_in_seconds.map(Duration::from_secs),
87 settings_version,
88 );
89 }
90 (
91 CompletionProvider::ZedDotDev(provider),
92 AssistantProvider::ZedDotDev { model },
93 ) => {
94 provider.update(model.clone(), settings_version);
95 }
96 (_, AssistantProvider::ZedDotDev { model }) => {
97 *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
98 model.clone(),
99 client.clone(),
100 settings_version,
101 cx,
102 ));
103 }
104 (
105 _,
106 AssistantProvider::OpenAi {
107 model,
108 api_url,
109 low_speed_timeout_in_seconds,
110 },
111 ) => {
112 *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
113 model.clone(),
114 api_url.clone(),
115 client.http_client(),
116 low_speed_timeout_in_seconds.map(Duration::from_secs),
117 settings_version,
118 ));
119 }
120 (
121 _,
122 AssistantProvider::Anthropic {
123 model,
124 api_url,
125 low_speed_timeout_in_seconds,
126 },
127 ) => {
128 *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
129 model.clone(),
130 api_url.clone(),
131 client.http_client(),
132 low_speed_timeout_in_seconds.map(Duration::from_secs),
133 settings_version,
134 ));
135 }
136 }
137 })
138 })
139 .detach();
140}
141
/// The completion backend the assistant is currently configured to use.
///
/// A single instance is installed as a gpui global by `init`. It is updated in
/// place or replaced whenever the settings change; read it with
/// [`CompletionProvider::global`].
pub enum CompletionProvider {
    /// Backed by an [`OpenAiCompletionProvider`].
    OpenAi(OpenAiCompletionProvider),
    /// Backed by an [`AnthropicCompletionProvider`].
    Anthropic(AnthropicCompletionProvider),
    /// Backed by a [`ZedDotDevCompletionProvider`].
    ZedDotDev(ZedDotDevCompletionProvider),
    /// Test-only provider; this variant exists only under `cfg(test)`.
    #[cfg(test)]
    Fake(FakeCompletionProvider),
}
149
/// Marks `CompletionProvider` as storable in gpui's global state, which `init`
/// populates and `CompletionProvider::global` reads.
impl gpui::Global for CompletionProvider {}
151
152impl CompletionProvider {
153 pub fn global(cx: &AppContext) -> &Self {
154 cx.global::<Self>()
155 }
156
157 pub fn available_models(&self) -> Vec<LanguageModel> {
158 match self {
159 CompletionProvider::OpenAi(provider) => provider
160 .available_models()
161 .map(LanguageModel::OpenAi)
162 .collect(),
163 CompletionProvider::Anthropic(provider) => provider
164 .available_models()
165 .map(LanguageModel::Anthropic)
166 .collect(),
167 CompletionProvider::ZedDotDev(provider) => provider
168 .available_models()
169 .map(LanguageModel::ZedDotDev)
170 .collect(),
171 #[cfg(test)]
172 CompletionProvider::Fake(_) => unimplemented!(),
173 }
174 }
175
176 pub fn settings_version(&self) -> usize {
177 match self {
178 CompletionProvider::OpenAi(provider) => provider.settings_version(),
179 CompletionProvider::Anthropic(provider) => provider.settings_version(),
180 CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
181 #[cfg(test)]
182 CompletionProvider::Fake(_) => unimplemented!(),
183 }
184 }
185
186 pub fn is_authenticated(&self) -> bool {
187 match self {
188 CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
189 CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
190 CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
191 #[cfg(test)]
192 CompletionProvider::Fake(_) => true,
193 }
194 }
195
196 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
197 match self {
198 CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
199 CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
200 CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
201 #[cfg(test)]
202 CompletionProvider::Fake(_) => Task::ready(Ok(())),
203 }
204 }
205
206 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
207 match self {
208 CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
209 CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
210 CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
211 #[cfg(test)]
212 CompletionProvider::Fake(_) => unimplemented!(),
213 }
214 }
215
216 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
217 match self {
218 CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
219 CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
220 CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
221 #[cfg(test)]
222 CompletionProvider::Fake(_) => Task::ready(Ok(())),
223 }
224 }
225
226 pub fn model(&self) -> LanguageModel {
227 match self {
228 CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
229 CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
230 CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()),
231 #[cfg(test)]
232 CompletionProvider::Fake(_) => LanguageModel::default(),
233 }
234 }
235
236 pub fn count_tokens(
237 &self,
238 request: LanguageModelRequest,
239 cx: &AppContext,
240 ) -> BoxFuture<'static, Result<usize>> {
241 match self {
242 CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
243 CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
244 CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
245 #[cfg(test)]
246 CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
247 }
248 }
249
250 pub fn complete(
251 &self,
252 request: LanguageModelRequest,
253 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
254 match self {
255 CompletionProvider::OpenAi(provider) => provider.complete(request),
256 CompletionProvider::Anthropic(provider) => provider.complete(request),
257 CompletionProvider::ZedDotDev(provider) => provider.complete(request),
258 #[cfg(test)]
259 CompletionProvider::Fake(provider) => provider.complete(),
260 }
261 }
262}