1#[cfg(test)]
2mod fake;
3mod open_ai;
4mod zed;
5
6#[cfg(test)]
7pub use fake::*;
8pub use open_ai::*;
9pub use zed::*;
10
11use crate::{
12 assistant_settings::{AssistantProvider, AssistantSettings},
13 LanguageModel, LanguageModelRequest,
14};
15use anyhow::Result;
16use client::Client;
17use futures::{future::BoxFuture, stream::BoxStream};
18use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
19use settings::{Settings, SettingsStore};
20use std::sync::Arc;
21use std::time::Duration;
22
23pub fn init(client: Arc<Client>, cx: &mut AppContext) {
24 let mut settings_version = 0;
25 let provider = match &AssistantSettings::get_global(cx).provider {
26 AssistantProvider::ZedDotDev { default_model } => {
27 CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
28 default_model.clone(),
29 client.clone(),
30 settings_version,
31 cx,
32 ))
33 }
34 AssistantProvider::OpenAi {
35 default_model,
36 api_url,
37 low_speed_timeout_in_seconds,
38 } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
39 default_model.clone(),
40 api_url.clone(),
41 client.http_client(),
42 low_speed_timeout_in_seconds.map(Duration::from_secs),
43 settings_version,
44 )),
45 };
46 cx.set_global(provider);
47
48 cx.observe_global::<SettingsStore>(move |cx| {
49 settings_version += 1;
50 cx.update_global::<CompletionProvider, _>(|provider, cx| {
51 match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
52 (
53 CompletionProvider::OpenAi(provider),
54 AssistantProvider::OpenAi {
55 default_model,
56 api_url,
57 low_speed_timeout_in_seconds,
58 },
59 ) => {
60 provider.update(
61 default_model.clone(),
62 api_url.clone(),
63 low_speed_timeout_in_seconds.map(Duration::from_secs),
64 settings_version,
65 );
66 }
67 (
68 CompletionProvider::ZedDotDev(provider),
69 AssistantProvider::ZedDotDev { default_model },
70 ) => {
71 provider.update(default_model.clone(), settings_version);
72 }
73 (CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
74 *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
75 default_model.clone(),
76 client.clone(),
77 settings_version,
78 cx,
79 ));
80 }
81 (
82 CompletionProvider::ZedDotDev(_),
83 AssistantProvider::OpenAi {
84 default_model,
85 api_url,
86 low_speed_timeout_in_seconds,
87 },
88 ) => {
89 *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
90 default_model.clone(),
91 api_url.clone(),
92 client.http_client(),
93 low_speed_timeout_in_seconds.map(Duration::from_secs),
94 settings_version,
95 ));
96 }
97 #[cfg(test)]
98 (CompletionProvider::Fake(_), _) => unimplemented!(),
99 }
100 })
101 })
102 .detach();
103}
104
/// The language-model backend used by the assistant, stored as a gpui global.
///
/// `init` selects the variant from `AssistantSettings`. When the configured
/// provider kind changes, `init`'s settings observer swaps in a new variant.
pub enum CompletionProvider {
    /// Backend configured from `AssistantProvider::OpenAi` settings.
    OpenAi(OpenAiCompletionProvider),
    /// Backend configured from `AssistantProvider::ZedDotDev` settings.
    ZedDotDev(ZedDotDevCompletionProvider),
    /// Fake backend, only compiled in test builds.
    #[cfg(test)]
    Fake(FakeCompletionProvider),
}

// Marker impl so the provider can be stored in and read from gpui's global
// state (`cx.set_global` / `cx.global`).
impl gpui::Global for CompletionProvider {}
113
114impl CompletionProvider {
115 pub fn global(cx: &AppContext) -> &Self {
116 cx.global::<Self>()
117 }
118
119 pub fn settings_version(&self) -> usize {
120 match self {
121 CompletionProvider::OpenAi(provider) => provider.settings_version(),
122 CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
123 #[cfg(test)]
124 CompletionProvider::Fake(_) => unimplemented!(),
125 }
126 }
127
128 pub fn is_authenticated(&self) -> bool {
129 match self {
130 CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
131 CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
132 #[cfg(test)]
133 CompletionProvider::Fake(_) => true,
134 }
135 }
136
137 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
138 match self {
139 CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
140 CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
141 #[cfg(test)]
142 CompletionProvider::Fake(_) => Task::ready(Ok(())),
143 }
144 }
145
146 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
147 match self {
148 CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
149 CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
150 #[cfg(test)]
151 CompletionProvider::Fake(_) => unimplemented!(),
152 }
153 }
154
155 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
156 match self {
157 CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
158 CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
159 #[cfg(test)]
160 CompletionProvider::Fake(_) => Task::ready(Ok(())),
161 }
162 }
163
164 pub fn default_model(&self) -> LanguageModel {
165 match self {
166 CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
167 CompletionProvider::ZedDotDev(provider) => {
168 LanguageModel::ZedDotDev(provider.default_model())
169 }
170 #[cfg(test)]
171 CompletionProvider::Fake(_) => unimplemented!(),
172 }
173 }
174
175 pub fn count_tokens(
176 &self,
177 request: LanguageModelRequest,
178 cx: &AppContext,
179 ) -> BoxFuture<'static, Result<usize>> {
180 match self {
181 CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
182 CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
183 #[cfg(test)]
184 CompletionProvider::Fake(_) => unimplemented!(),
185 }
186 }
187
188 pub fn complete(
189 &self,
190 request: LanguageModelRequest,
191 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
192 match self {
193 CompletionProvider::OpenAi(provider) => provider.complete(request),
194 CompletionProvider::ZedDotDev(provider) => provider.complete(request),
195 #[cfg(test)]
196 CompletionProvider::Fake(provider) => provider.complete(),
197 }
198 }
199}