1use std::sync::Arc;
2
3use anyhow::Result;
4use copilot::copilot_chat::{
5 ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
6 Role as CopilotChatRole,
7};
8use copilot::{Copilot, Status};
9use futures::future::BoxFuture;
10use futures::stream::BoxStream;
11use futures::{FutureExt, StreamExt};
12use gpui::{
13 percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model,
14 ModelContext, Render, Subscription, Task, Transformation,
15};
16use settings::{Settings, SettingsStore};
17use std::time::Duration;
18use strum::IntoEnumIterator;
19use ui::{
20 div, v_flex, Button, ButtonCommon, Clickable, Color, Context, FixedWidth, IconName,
21 IconPosition, IconSize, IntoElement, Label, LabelCommon, ParentElement, Styled, ViewContext,
22 VisualContext, WindowContext,
23};
24
25use crate::settings::AllLanguageModelSettings;
26use crate::LanguageModelProviderState;
27use crate::{
28 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
29 LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, Role,
30};
31
32use super::open_ai::count_open_ai_tokens;
33
/// Stable identifier reported both by the provider and by each of its models.
const PROVIDER_ID: &str = "copilot_chat";

/// Human-readable name reported both by the provider and by each of its models.
const PROVIDER_NAME: &str = "GitHub Copilot Chat";

/// User-configurable settings for the Copilot Chat provider.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CopilotChatSettings {
    /// Optional timeout forwarded to `CopilotChat::stream_completion` for every request;
    /// `None` (the default) when not configured.
    /// NOTE(review): presumably aborts responses that stream too slowly — confirm in the
    /// `copilot` crate.
    pub low_speed_timeout: Option<Duration>,
}
41
/// Language-model provider that exposes GitHub Copilot Chat models.
pub struct CopilotChatLanguageModelProvider {
    // Shared state model; observers of the provider subscribe to it (see `subscribe`)
    // and it is notified after credentials are reset.
    state: Model<State>,
}
45
/// Observable state backing `CopilotChatLanguageModelProvider`.
///
/// The fields are never read; they are held so the subscriptions created in
/// `CopilotChatLanguageModelProvider::new` stay registered for as long as this model lives
/// (presumably dropping a gpui `Subscription` cancels it — confirm in gpui).
pub struct State {
    // Re-notifies this model when the global `CopilotChat` model changes;
    // `None` if no global `CopilotChat` existed at construction time.
    _copilot_chat_subscription: Option<Subscription>,
    // Re-notifies this model whenever the global `SettingsStore` changes.
    _settings_subscription: Subscription,
}
50
51impl CopilotChatLanguageModelProvider {
52 pub fn new(cx: &mut AppContext) -> Self {
53 let state = cx.new_model(|cx| {
54 let _copilot_chat_subscription = CopilotChat::global(cx)
55 .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
56 State {
57 _copilot_chat_subscription,
58 _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
59 cx.notify();
60 }),
61 }
62 });
63
64 Self { state }
65 }
66}
67
68impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
69 fn subscribe<T: 'static>(&self, cx: &mut ModelContext<T>) -> Option<Subscription> {
70 Some(cx.observe(&self.state, |_, _, cx| {
71 cx.notify();
72 }))
73 }
74}
75
76impl LanguageModelProvider for CopilotChatLanguageModelProvider {
77 fn id(&self) -> LanguageModelProviderId {
78 LanguageModelProviderId(PROVIDER_ID.into())
79 }
80
81 fn name(&self) -> LanguageModelProviderName {
82 LanguageModelProviderName(PROVIDER_NAME.into())
83 }
84
85 fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
86 CopilotChatModel::iter()
87 .map(|model| Arc::new(CopilotChatLanguageModel { model }) as Arc<dyn LanguageModel>)
88 .collect()
89 }
90
91 fn is_authenticated(&self, cx: &AppContext) -> bool {
92 CopilotChat::global(cx)
93 .map(|m| m.read(cx).is_authenticated())
94 .unwrap_or(false)
95 }
96
97 fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
98 let result = if self.is_authenticated(cx) {
99 Ok(())
100 } else if let Some(copilot) = Copilot::global(cx) {
101 let error_msg = match copilot.read(cx).status() {
102 Status::Disabled => anyhow::anyhow!("Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."),
103 Status::Error(e) => anyhow::anyhow!(format!("Received the following error while signing into Copilot: {e}")),
104 Status::Starting { task: _ } => anyhow::anyhow!("Copilot is still starting, please wait for Copilot to start then try again"),
105 Status::Unauthorized => anyhow::anyhow!("Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."),
106 Status::Authorized => return Task::ready(Ok(())),
107 Status::SignedOut => anyhow::anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again."),
108 Status::SigningIn { prompt: _ } => anyhow::anyhow!("Still signing into Copilot..."),
109 };
110 Err(error_msg)
111 } else {
112 Err(anyhow::anyhow!(
113 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
114 ))
115 };
116 Task::ready(result)
117 }
118
119 fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
120 cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
121 }
122
123 fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
124 let Some(copilot) = Copilot::global(cx) else {
125 return Task::ready(Err(anyhow::anyhow!(
126 "Copilot is not available. Please ensure Copilot is enabled and running and try again."
127 )));
128 };
129
130 let state = self.state.clone();
131
132 cx.spawn(|mut cx| async move {
133 cx.update_model(&copilot, |model, cx| model.sign_out(cx))?
134 .await?;
135
136 cx.update_model(&state, |_, cx| {
137 cx.notify();
138 })?;
139
140 Ok(())
141 })
142 }
143}
144
/// A single Copilot Chat model exposed through the generic `LanguageModel` trait.
pub struct CopilotChatLanguageModel {
    // Which Copilot Chat model variant requests are sent to.
    model: CopilotChatModel,
}
148
149impl LanguageModel for CopilotChatLanguageModel {
150 fn id(&self) -> LanguageModelId {
151 LanguageModelId::from(self.model.id().to_string())
152 }
153
154 fn name(&self) -> LanguageModelName {
155 LanguageModelName::from(self.model.display_name().to_string())
156 }
157
158 fn provider_id(&self) -> LanguageModelProviderId {
159 LanguageModelProviderId(PROVIDER_ID.into())
160 }
161
162 fn provider_name(&self) -> LanguageModelProviderName {
163 LanguageModelProviderName(PROVIDER_NAME.into())
164 }
165
166 fn telemetry_id(&self) -> String {
167 format!("copilot_chat/{}", self.model.id())
168 }
169
170 fn max_token_count(&self) -> usize {
171 self.model.max_token_count()
172 }
173
174 fn count_tokens(
175 &self,
176 request: LanguageModelRequest,
177 cx: &AppContext,
178 ) -> BoxFuture<'static, Result<usize>> {
179 let model = match self.model {
180 CopilotChatModel::Gpt4 => open_ai::Model::Four,
181 CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
182 };
183
184 count_open_ai_tokens(request, model, cx)
185 }
186
187 fn stream_completion(
188 &self,
189 request: LanguageModelRequest,
190 cx: &AsyncAppContext,
191 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
192 if let Some(message) = request.messages.last() {
193 if message.content.trim().is_empty() {
194 const EMPTY_PROMPT_MSG: &str =
195 "Empty prompts aren't allowed. Please provide a non-empty prompt.";
196 return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG))).boxed();
197 }
198
199 // Copilot Chat has a restriction that the final message must be from the user.
200 // While their API does return an error message for this, we can catch it earlier
201 // and provide a more helpful error message.
202 if !matches!(message.role, Role::User) {
203 const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
204 return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG))).boxed();
205 }
206 }
207
208 let request = self.to_copilot_chat_request(request);
209 let Ok(low_speed_timeout) = cx.update(|cx| {
210 AllLanguageModelSettings::get_global(cx)
211 .copilot_chat
212 .low_speed_timeout
213 }) else {
214 return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
215 };
216
217 cx.spawn(|mut cx| async move {
218 let response = CopilotChat::stream_completion(request, low_speed_timeout, &mut cx).await?;
219 let stream = response
220 .filter_map(|response| async move {
221 match response {
222 Ok(result) => {
223 let choice = result.choices.first();
224 match choice {
225 Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
226 None => Some(Err(anyhow::anyhow!(
227 "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
228 ))),
229 }
230 }
231 Err(err) => Some(Err(err)),
232 }
233 })
234 .boxed();
235 Ok(stream)
236 })
237 .boxed()
238 }
239}
240
241impl CopilotChatLanguageModel {
242 pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest {
243 CopilotChatRequest::new(
244 self.model.clone(),
245 request
246 .messages
247 .into_iter()
248 .map(|msg| ChatMessage {
249 role: match msg.role {
250 Role::User => CopilotChatRole::User,
251 Role::Assistant => CopilotChatRole::Assistant,
252 Role::System => CopilotChatRole::System,
253 },
254 content: msg.content,
255 })
256 .collect(),
257 )
258 }
259}
260
/// View prompting the user to get Copilot into a state where Copilot Chat can be used.
struct AuthenticationPrompt {
    // Latest Copilot status seen; `None` when no global Copilot instance existed at
    // construction time. Refreshed by the observer held in `_subscription`.
    copilot_status: Option<copilot::Status>,
    // Observer of the Copilot model; never read, only held (presumably to keep the
    // observation alive — confirm gpui `Subscription` drop semantics).
    _subscription: Option<Subscription>,
}
265
266impl AuthenticationPrompt {
267 pub fn new(cx: &mut ViewContext<Self>) -> Self {
268 let copilot = Copilot::global(cx);
269
270 Self {
271 copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
272 _subscription: copilot.as_ref().map(|copilot| {
273 cx.observe(copilot, |this, model, cx| {
274 this.copilot_status = Some(model.read(cx).status());
275 cx.notify();
276 })
277 }),
278 }
279 }
280}
281
282impl Render for AuthenticationPrompt {
283 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
284 let loading_icon = svg()
285 .size_8()
286 .path(IconName::ArrowCircle.path())
287 .text_color(cx.text_style().color)
288 .with_animation(
289 "icon_circle_arrow",
290 Animation::new(Duration::from_secs(2)).repeat(),
291 |svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))),
292 );
293
294 const ERROR_LABEL: &str = "Copilot Chat requires the Copilot plugin to be available and running. Please ensure Copilot is running and try again, or use a different Assistant provider.";
295 match &self.copilot_status {
296 Some(status) => match status {
297 Status::Disabled => {
298 return v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL));
299 }
300 Status::Starting { task: _ } => {
301 const LABEL: &str = "Starting Copilot...";
302 return v_flex()
303 .gap_6()
304 .p_4()
305 .justify_center()
306 .items_center()
307 .child(Label::new(LABEL))
308 .child(loading_icon);
309 }
310 Status::SigningIn { prompt: _ } => {
311 const LABEL: &str = "Signing in to Copilot...";
312 return v_flex()
313 .gap_6()
314 .p_4()
315 .justify_center()
316 .items_center()
317 .child(Label::new(LABEL))
318 .child(loading_icon);
319 }
320 Status::Error(_) => {
321 const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
322 return v_flex()
323 .gap_6()
324 .p_4()
325 .child(Label::new(LABEL))
326 .child(svg().size_8().path(IconName::CopilotError.path()));
327 }
328 _ => {
329 const LABEL: &str =
330 "To use the assistant panel or inline assistant, you must login to GitHub Copilot. Your GitHub account must have an active Copilot Chat subscription.";
331 v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
332 v_flex()
333 .gap_2()
334 .child(
335 Button::new("sign_in", "Sign In")
336 .icon_color(Color::Muted)
337 .icon(IconName::Github)
338 .icon_position(IconPosition::Start)
339 .icon_size(IconSize::Medium)
340 .style(ui::ButtonStyle::Filled)
341 .full_width()
342 .on_click(|_, cx| {
343 inline_completion_button::initiate_sign_in(cx)
344 }),
345 )
346 .child(
347 div().flex().w_full().items_center().child(
348 Label::new("Sign in to start using Github Copilot Chat.")
349 .color(Color::Muted)
350 .size(ui::LabelSize::Small),
351 ),
352 ),
353 )
354 }
355 },
356 None => v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL)),
357 }
358 }
359}