copilot_chat.rs

  1use std::future;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use copilot::copilot_chat::{
  6    ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
  7    Role as CopilotChatRole,
  8};
  9use copilot::{Copilot, Status};
 10use futures::future::BoxFuture;
 11use futures::stream::BoxStream;
 12use futures::{FutureExt, StreamExt};
 13use gpui::{
 14    percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model,
 15    ModelContext, Render, Subscription, Task, Transformation,
 16};
 17use settings::{Settings, SettingsStore};
 18use std::time::Duration;
 19use strum::IntoEnumIterator;
 20use ui::{
 21    div, v_flex, Button, ButtonCommon, Clickable, Color, Context, FixedWidth, IconName,
 22    IconPosition, IconSize, IntoElement, Label, LabelCommon, ParentElement, Styled, ViewContext,
 23    VisualContext, WindowContext,
 24};
 25
 26use crate::settings::AllLanguageModelSettings;
 27use crate::LanguageModelProviderState;
 28use crate::{
 29    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
 30    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, Role,
 31};
 32
 33use super::open_ai::count_open_ai_tokens;
 34
/// Identifier wrapped into `LanguageModelProviderId` by the `id` / `provider_id` methods below.
const PROVIDER_ID: &str = "copilot_chat";
/// Display name wrapped into `LanguageModelProviderName` by the `name` / `provider_name` methods below.
const PROVIDER_NAME: &str = "GitHub Copilot Chat";
 37
/// Settings for the Copilot Chat provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatSettings {
    /// Optional timeout; `None` when not configured.
    /// NOTE(review): presumably the value `stream_completion` reads via
    /// `AllLanguageModelSettings::get_global(cx).copilot_chat` — confirm the field type there.
    pub low_speed_timeout: Option<Duration>,
}
 42
/// Language model provider backed by GitHub Copilot Chat.
pub struct CopilotChatLanguageModelProvider {
    /// Model handle observed in `subscribe` and notified after `reset_credentials` signs out.
    state: Model<State>,
}
 46
/// Holds the subscriptions created in `CopilotChatLanguageModelProvider::new`.
///
/// Neither field is read anywhere in this file.
/// NOTE(review): presumably they are stored only so the subscriptions stay active for the
/// lifetime of the model — confirm against gpui's `Subscription` drop semantics.
pub struct State {
    /// Observation of the global `CopilotChat` model; `None` when no global instance existed
    /// at construction time.
    _copilot_chat_subscription: Option<Subscription>,
    /// Observation of the global `SettingsStore`.
    _settings_subscription: Subscription,
}
 51
 52impl CopilotChatLanguageModelProvider {
 53    pub fn new(cx: &mut AppContext) -> Self {
 54        let state = cx.new_model(|cx| {
 55            let _copilot_chat_subscription = CopilotChat::global(cx)
 56                .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
 57            State {
 58                _copilot_chat_subscription,
 59                _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 60                    cx.notify();
 61                }),
 62            }
 63        });
 64
 65        Self { state }
 66    }
 67}
 68
 69impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
 70    fn subscribe<T: 'static>(&self, cx: &mut ModelContext<T>) -> Option<Subscription> {
 71        Some(cx.observe(&self.state, |_, _, cx| {
 72            cx.notify();
 73        }))
 74    }
 75}
 76
 77impl LanguageModelProvider for CopilotChatLanguageModelProvider {
 78    fn id(&self) -> LanguageModelProviderId {
 79        LanguageModelProviderId(PROVIDER_ID.into())
 80    }
 81
 82    fn name(&self) -> LanguageModelProviderName {
 83        LanguageModelProviderName(PROVIDER_NAME.into())
 84    }
 85
 86    fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 87        CopilotChatModel::iter()
 88            .map(|model| Arc::new(CopilotChatLanguageModel { model }) as Arc<dyn LanguageModel>)
 89            .collect()
 90    }
 91
 92    fn is_authenticated(&self, cx: &AppContext) -> bool {
 93        CopilotChat::global(cx)
 94            .map(|m| m.read(cx).is_authenticated())
 95            .unwrap_or(false)
 96    }
 97
 98    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 99        let result = if self.is_authenticated(cx) {
100            Ok(())
101        } else if let Some(copilot) = Copilot::global(cx) {
102            let error_msg = match copilot.read(cx).status() {
103                Status::Disabled => anyhow::anyhow!("Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."),
104                Status::Error(e) => anyhow::anyhow!(format!("Received the following error while signing into Copilot: {e}")),
105                Status::Starting { task: _ } => anyhow::anyhow!("Copilot is still starting, please wait for Copilot to start then try again"),
106                Status::Unauthorized => anyhow::anyhow!("Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."),
107                Status::Authorized => return Task::ready(Ok(())),
108                Status::SignedOut => anyhow::anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again."),
109                Status::SigningIn { prompt: _ } => anyhow::anyhow!("Still signing into Copilot..."),
110            };
111            Err(error_msg)
112        } else {
113            Err(anyhow::anyhow!(
114                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
115            ))
116        };
117        Task::ready(result)
118    }
119
120    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
121        cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
122    }
123
124    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
125        let Some(copilot) = Copilot::global(cx) else {
126            return Task::ready(Err(anyhow::anyhow!(
127                "Copilot is not available. Please ensure Copilot is enabled and running and try again."
128            )));
129        };
130
131        let state = self.state.clone();
132
133        cx.spawn(|mut cx| async move {
134            cx.update_model(&copilot, |model, cx| model.sign_out(cx))?
135                .await?;
136
137            cx.update_model(&state, |_, cx| {
138                cx.notify();
139            })?;
140
141            Ok(())
142        })
143    }
144}
145
/// A single Copilot Chat model exposed through the `LanguageModel` trait.
pub struct CopilotChatLanguageModel {
    /// The underlying Copilot Chat model this wrapper delegates to for id, display name and
    /// token limit.
    model: CopilotChatModel,
}
149
150impl LanguageModel for CopilotChatLanguageModel {
151    fn id(&self) -> LanguageModelId {
152        LanguageModelId::from(self.model.id().to_string())
153    }
154
155    fn name(&self) -> LanguageModelName {
156        LanguageModelName::from(self.model.display_name().to_string())
157    }
158
159    fn provider_id(&self) -> LanguageModelProviderId {
160        LanguageModelProviderId(PROVIDER_ID.into())
161    }
162
163    fn provider_name(&self) -> LanguageModelProviderName {
164        LanguageModelProviderName(PROVIDER_NAME.into())
165    }
166
167    fn telemetry_id(&self) -> String {
168        format!("copilot_chat/{}", self.model.id())
169    }
170
171    fn max_token_count(&self) -> usize {
172        self.model.max_token_count()
173    }
174
175    fn count_tokens(
176        &self,
177        request: LanguageModelRequest,
178        cx: &AppContext,
179    ) -> BoxFuture<'static, Result<usize>> {
180        let model = match self.model {
181            CopilotChatModel::Gpt4 => open_ai::Model::Four,
182            CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
183        };
184
185        count_open_ai_tokens(request, model, cx)
186    }
187
188    fn stream_completion(
189        &self,
190        request: LanguageModelRequest,
191        cx: &AsyncAppContext,
192    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
193        if let Some(message) = request.messages.last() {
194            if message.content.trim().is_empty() {
195                const EMPTY_PROMPT_MSG: &str =
196                    "Empty prompts aren't allowed. Please provide a non-empty prompt.";
197                return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG))).boxed();
198            }
199
200            // Copilot Chat has a restriction that the final message must be from the user.
201            // While their API does return an error message for this, we can catch it earlier
202            // and provide a more helpful error message.
203            if !matches!(message.role, Role::User) {
204                const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
205                return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG))).boxed();
206            }
207        }
208
209        let request = self.to_copilot_chat_request(request);
210        let Ok(low_speed_timeout) = cx.update(|cx| {
211            AllLanguageModelSettings::get_global(cx)
212                .copilot_chat
213                .low_speed_timeout
214        }) else {
215            return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
216        };
217
218        cx.spawn(|mut cx| async move {
219            let response = CopilotChat::stream_completion(request, low_speed_timeout, &mut cx).await?;
220            let stream = response
221                .filter_map(|response| async move {
222                    match response {
223                        Ok(result) => {
224                            let choice = result.choices.first();
225                            match choice {
226                                Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
227                                None => Some(Err(anyhow::anyhow!(
228                                    "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
229                                ))),
230                            }
231                        }
232                        Err(err) => Some(Err(err)),
233                    }
234                })
235                .boxed();
236            Ok(stream)
237        })
238        .boxed()
239    }
240
241    fn use_tool(
242        &self,
243        _request: LanguageModelRequest,
244        _name: String,
245        _description: String,
246        _schema: serde_json::Value,
247        _cx: &AsyncAppContext,
248    ) -> BoxFuture<'static, Result<serde_json::Value>> {
249        future::ready(Err(anyhow!("not implemented"))).boxed()
250    }
251}
252
253impl CopilotChatLanguageModel {
254    pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest {
255        CopilotChatRequest::new(
256            self.model.clone(),
257            request
258                .messages
259                .into_iter()
260                .map(|msg| ChatMessage {
261                    role: match msg.role {
262                        Role::User => CopilotChatRole::User,
263                        Role::Assistant => CopilotChatRole::Assistant,
264                        Role::System => CopilotChatRole::System,
265                    },
266                    content: msg.content,
267                })
268                .collect(),
269        )
270    }
271}
272
/// View that tells the user what is needed before Copilot Chat can be used.
struct AuthenticationPrompt {
    /// Most recently observed Copilot status; `None` when no global `Copilot` instance existed
    /// when the view was created.
    copilot_status: Option<copilot::Status>,
    /// Observation of the global `Copilot` model that refreshes `copilot_status`; `None` when
    /// there was no instance to observe.
    _subscription: Option<Subscription>,
}
277
278impl AuthenticationPrompt {
279    pub fn new(cx: &mut ViewContext<Self>) -> Self {
280        let copilot = Copilot::global(cx);
281
282        Self {
283            copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
284            _subscription: copilot.as_ref().map(|copilot| {
285                cx.observe(copilot, |this, model, cx| {
286                    this.copilot_status = Some(model.read(cx).status());
287                    cx.notify();
288                })
289            }),
290        }
291    }
292}
293
294impl Render for AuthenticationPrompt {
295    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
296        let loading_icon = svg()
297            .size_8()
298            .path(IconName::ArrowCircle.path())
299            .text_color(cx.text_style().color)
300            .with_animation(
301                "icon_circle_arrow",
302                Animation::new(Duration::from_secs(2)).repeat(),
303                |svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))),
304            );
305
306        const ERROR_LABEL: &str = "Copilot Chat requires the Copilot plugin to be available and running. Please ensure Copilot is running and try again, or use a different Assistant provider.";
307        match &self.copilot_status {
308            Some(status) => match status {
309                Status::Disabled => {
310                    return v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL));
311                }
312                Status::Starting { task: _ } => {
313                    const LABEL: &str = "Starting Copilot...";
314                    return v_flex()
315                        .gap_6()
316                        .p_4()
317                        .justify_center()
318                        .items_center()
319                        .child(Label::new(LABEL))
320                        .child(loading_icon);
321                }
322                Status::SigningIn { prompt: _ } => {
323                    const LABEL: &str = "Signing in to Copilot...";
324                    return v_flex()
325                        .gap_6()
326                        .p_4()
327                        .justify_center()
328                        .items_center()
329                        .child(Label::new(LABEL))
330                        .child(loading_icon);
331                }
332                Status::Error(_) => {
333                    const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
334                    return v_flex()
335                        .gap_6()
336                        .p_4()
337                        .child(Label::new(LABEL))
338                        .child(svg().size_8().path(IconName::CopilotError.path()));
339                }
340                _ => {
341                    const LABEL: &str =
342                    "To use the assistant panel or inline assistant, you must login to GitHub Copilot. Your GitHub account must have an active Copilot Chat subscription.";
343                    v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
344                        v_flex()
345                            .gap_2()
346                            .child(
347                                Button::new("sign_in", "Sign In")
348                                    .icon_color(Color::Muted)
349                                    .icon(IconName::Github)
350                                    .icon_position(IconPosition::Start)
351                                    .icon_size(IconSize::Medium)
352                                    .style(ui::ButtonStyle::Filled)
353                                    .full_width()
354                                    .on_click(|_, cx| {
355                                        inline_completion_button::initiate_sign_in(cx)
356                                    }),
357                            )
358                            .child(
359                                div().flex().w_full().items_center().child(
360                                    Label::new("Sign in to start using Github Copilot Chat.")
361                                        .color(Color::Muted)
362                                        .size(ui::LabelSize::Small),
363                                ),
364                            ),
365                    )
366                }
367            },
368            None => v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL)),
369        }
370    }
371}