copilot_chat.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use copilot::copilot_chat::{
  5    ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
  6    Role as CopilotChatRole,
  7};
  8use copilot::{Copilot, Status};
  9use futures::future::BoxFuture;
 10use futures::stream::BoxStream;
 11use futures::{FutureExt, StreamExt};
 12use gpui::{
 13    percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model,
 14    ModelContext, Render, Subscription, Task, Transformation,
 15};
 16use settings::{Settings, SettingsStore};
 17use std::time::Duration;
 18use strum::IntoEnumIterator;
 19use ui::{
 20    div, v_flex, Button, ButtonCommon, Clickable, Color, Context, FixedWidth, IconName,
 21    IconPosition, IconSize, IntoElement, Label, LabelCommon, ParentElement, Styled, ViewContext,
 22    VisualContext, WindowContext,
 23};
 24
 25use crate::settings::AllLanguageModelSettings;
 26use crate::LanguageModelProviderState;
 27use crate::{
 28    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
 29    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, Role,
 30};
 31
 32use super::open_ai::count_open_ai_tokens;
 33
 34const PROVIDER_ID: &str = "copilot_chat";
 35const PROVIDER_NAME: &str = "GitHub Copilot Chat";
 36
 37#[derive(Default, Clone, Debug, PartialEq)]
 38pub struct CopilotChatSettings {
 39    pub low_speed_timeout: Option<Duration>,
 40}
 41
 42pub struct CopilotChatLanguageModelProvider {
 43    state: Model<State>,
 44}
 45
 46pub struct State {
 47    _copilot_chat_subscription: Option<Subscription>,
 48    _settings_subscription: Subscription,
 49}
 50
 51impl CopilotChatLanguageModelProvider {
 52    pub fn new(cx: &mut AppContext) -> Self {
 53        let state = cx.new_model(|cx| {
 54            let _copilot_chat_subscription = CopilotChat::global(cx)
 55                .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
 56            State {
 57                _copilot_chat_subscription,
 58                _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 59                    cx.notify();
 60                }),
 61            }
 62        });
 63
 64        Self { state }
 65    }
 66}
 67
 68impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
 69    fn subscribe<T: 'static>(&self, cx: &mut ModelContext<T>) -> Option<Subscription> {
 70        Some(cx.observe(&self.state, |_, _, cx| {
 71            cx.notify();
 72        }))
 73    }
 74}
 75
 76impl LanguageModelProvider for CopilotChatLanguageModelProvider {
 77    fn id(&self) -> LanguageModelProviderId {
 78        LanguageModelProviderId(PROVIDER_ID.into())
 79    }
 80
 81    fn name(&self) -> LanguageModelProviderName {
 82        LanguageModelProviderName(PROVIDER_NAME.into())
 83    }
 84
 85    fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 86        CopilotChatModel::iter()
 87            .map(|model| Arc::new(CopilotChatLanguageModel { model }) as Arc<dyn LanguageModel>)
 88            .collect()
 89    }
 90
 91    fn is_authenticated(&self, cx: &AppContext) -> bool {
 92        CopilotChat::global(cx)
 93            .map(|m| m.read(cx).is_authenticated())
 94            .unwrap_or(false)
 95    }
 96
 97    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 98        let result = if self.is_authenticated(cx) {
 99            Ok(())
100        } else if let Some(copilot) = Copilot::global(cx) {
101            let error_msg = match copilot.read(cx).status() {
102                Status::Disabled => anyhow::anyhow!("Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."),
103                Status::Error(e) => anyhow::anyhow!(format!("Received the following error while signing into Copilot: {e}")),
104                Status::Starting { task: _ } => anyhow::anyhow!("Copilot is still starting, please wait for Copilot to start then try again"),
105                Status::Unauthorized => anyhow::anyhow!("Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."),
106                Status::Authorized => return Task::ready(Ok(())),
107                Status::SignedOut => anyhow::anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again."),
108                Status::SigningIn { prompt: _ } => anyhow::anyhow!("Still signing into Copilot..."),
109            };
110            Err(error_msg)
111        } else {
112            Err(anyhow::anyhow!(
113                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
114            ))
115        };
116        Task::ready(result)
117    }
118
119    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
120        cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
121    }
122
123    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
124        let Some(copilot) = Copilot::global(cx) else {
125            return Task::ready(Err(anyhow::anyhow!(
126                "Copilot is not available. Please ensure Copilot is enabled and running and try again."
127            )));
128        };
129
130        let state = self.state.clone();
131
132        cx.spawn(|mut cx| async move {
133            cx.update_model(&copilot, |model, cx| model.sign_out(cx))?
134                .await?;
135
136            cx.update_model(&state, |_, cx| {
137                cx.notify();
138            })?;
139
140            Ok(())
141        })
142    }
143}
144
145pub struct CopilotChatLanguageModel {
146    model: CopilotChatModel,
147}
148
149impl LanguageModel for CopilotChatLanguageModel {
150    fn id(&self) -> LanguageModelId {
151        LanguageModelId::from(self.model.id().to_string())
152    }
153
154    fn name(&self) -> LanguageModelName {
155        LanguageModelName::from(self.model.display_name().to_string())
156    }
157
158    fn provider_id(&self) -> LanguageModelProviderId {
159        LanguageModelProviderId(PROVIDER_ID.into())
160    }
161
162    fn provider_name(&self) -> LanguageModelProviderName {
163        LanguageModelProviderName(PROVIDER_NAME.into())
164    }
165
166    fn telemetry_id(&self) -> String {
167        format!("copilot_chat/{}", self.model.id())
168    }
169
170    fn max_token_count(&self) -> usize {
171        self.model.max_token_count()
172    }
173
174    fn count_tokens(
175        &self,
176        request: LanguageModelRequest,
177        cx: &AppContext,
178    ) -> BoxFuture<'static, Result<usize>> {
179        let model = match self.model {
180            CopilotChatModel::Gpt4 => open_ai::Model::Four,
181            CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
182        };
183
184        count_open_ai_tokens(request, model, cx)
185    }
186
187    fn stream_completion(
188        &self,
189        request: LanguageModelRequest,
190        cx: &AsyncAppContext,
191    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
192        if let Some(message) = request.messages.last() {
193            if message.content.trim().is_empty() {
194                const EMPTY_PROMPT_MSG: &str =
195                    "Empty prompts aren't allowed. Please provide a non-empty prompt.";
196                return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG))).boxed();
197            }
198
199            // Copilot Chat has a restriction that the final message must be from the user.
200            // While their API does return an error message for this, we can catch it earlier
201            // and provide a more helpful error message.
202            if !matches!(message.role, Role::User) {
203                const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
204                return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG))).boxed();
205            }
206        }
207
208        let request = self.to_copilot_chat_request(request);
209        let Ok(low_speed_timeout) = cx.update(|cx| {
210            AllLanguageModelSettings::get_global(cx)
211                .copilot_chat
212                .low_speed_timeout
213        }) else {
214            return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
215        };
216
217        cx.spawn(|mut cx| async move {
218            let response = CopilotChat::stream_completion(request, low_speed_timeout, &mut cx).await?;
219            let stream = response
220                .filter_map(|response| async move {
221                    match response {
222                        Ok(result) => {
223                            let choice = result.choices.first();
224                            match choice {
225                                Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
226                                None => Some(Err(anyhow::anyhow!(
227                                    "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
228                                ))),
229                            }
230                        }
231                        Err(err) => Some(Err(err)),
232                    }
233                })
234                .boxed();
235            Ok(stream)
236        })
237        .boxed()
238    }
239}
240
241impl CopilotChatLanguageModel {
242    pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest {
243        CopilotChatRequest::new(
244            self.model.clone(),
245            request
246                .messages
247                .into_iter()
248                .map(|msg| ChatMessage {
249                    role: match msg.role {
250                        Role::User => CopilotChatRole::User,
251                        Role::Assistant => CopilotChatRole::Assistant,
252                        Role::System => CopilotChatRole::System,
253                    },
254                    content: msg.content,
255                })
256                .collect(),
257        )
258    }
259}
260
261struct AuthenticationPrompt {
262    copilot_status: Option<copilot::Status>,
263    _subscription: Option<Subscription>,
264}
265
266impl AuthenticationPrompt {
267    pub fn new(cx: &mut ViewContext<Self>) -> Self {
268        let copilot = Copilot::global(cx);
269
270        Self {
271            copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
272            _subscription: copilot.as_ref().map(|copilot| {
273                cx.observe(copilot, |this, model, cx| {
274                    this.copilot_status = Some(model.read(cx).status());
275                    cx.notify();
276                })
277            }),
278        }
279    }
280}
281
282impl Render for AuthenticationPrompt {
283    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
284        let loading_icon = svg()
285            .size_8()
286            .path(IconName::ArrowCircle.path())
287            .text_color(cx.text_style().color)
288            .with_animation(
289                "icon_circle_arrow",
290                Animation::new(Duration::from_secs(2)).repeat(),
291                |svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))),
292            );
293
294        const ERROR_LABEL: &str = "Copilot Chat requires the Copilot plugin to be available and running. Please ensure Copilot is running and try again, or use a different Assistant provider.";
295        match &self.copilot_status {
296            Some(status) => match status {
297                Status::Disabled => {
298                    return v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL));
299                }
300                Status::Starting { task: _ } => {
301                    const LABEL: &str = "Starting Copilot...";
302                    return v_flex()
303                        .gap_6()
304                        .p_4()
305                        .justify_center()
306                        .items_center()
307                        .child(Label::new(LABEL))
308                        .child(loading_icon);
309                }
310                Status::SigningIn { prompt: _ } => {
311                    const LABEL: &str = "Signing in to Copilot...";
312                    return v_flex()
313                        .gap_6()
314                        .p_4()
315                        .justify_center()
316                        .items_center()
317                        .child(Label::new(LABEL))
318                        .child(loading_icon);
319                }
320                Status::Error(_) => {
321                    const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
322                    return v_flex()
323                        .gap_6()
324                        .p_4()
325                        .child(Label::new(LABEL))
326                        .child(svg().size_8().path(IconName::CopilotError.path()));
327                }
328                _ => {
329                    const LABEL: &str =
330                    "To use the assistant panel or inline assistant, you must login to GitHub Copilot. Your GitHub account must have an active Copilot Chat subscription.";
331                    v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
332                        v_flex()
333                            .gap_2()
334                            .child(
335                                Button::new("sign_in", "Sign In")
336                                    .icon_color(Color::Muted)
337                                    .icon(IconName::Github)
338                                    .icon_position(IconPosition::Start)
339                                    .icon_size(IconSize::Medium)
340                                    .style(ui::ButtonStyle::Filled)
341                                    .full_width()
342                                    .on_click(|_, cx| {
343                                        inline_completion_button::initiate_sign_in(cx)
344                                    }),
345                            )
346                            .child(
347                                div().flex().w_full().items_center().child(
348                                    Label::new("Sign in to start using Github Copilot Chat.")
349                                        .color(Color::Muted)
350                                        .size(ui::LabelSize::Small),
351                                ),
352                            ),
353                    )
354                }
355            },
356            None => v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL)),
357        }
358    }
359}