copilot_chat.rs

  1use std::future;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use copilot::copilot_chat::{
  6    ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
  7    Role as CopilotChatRole,
  8};
  9use copilot::{Copilot, Status};
 10use futures::future::BoxFuture;
 11use futures::stream::BoxStream;
 12use futures::{FutureExt, StreamExt};
 13use gpui::{
 14    percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, Render,
 15    Subscription, Task, Transformation,
 16};
 17use settings::{Settings, SettingsStore};
 18use std::time::Duration;
 19use strum::IntoEnumIterator;
 20use ui::{
 21    div, v_flex, Button, ButtonCommon, Clickable, Color, Context, FixedWidth, IconName,
 22    IconPosition, IconSize, IntoElement, Label, LabelCommon, ParentElement, Styled, ViewContext,
 23    VisualContext, WindowContext,
 24};
 25
 26use crate::settings::AllLanguageModelSettings;
 27use crate::LanguageModelProviderState;
 28use crate::{
 29    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
 30    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
 31};
 32
 33use super::open_ai::count_open_ai_tokens;
 34
 35const PROVIDER_ID: &str = "copilot_chat";
 36const PROVIDER_NAME: &str = "GitHub Copilot Chat";
 37
 38#[derive(Default, Clone, Debug, PartialEq)]
 39pub struct CopilotChatSettings {
 40    pub low_speed_timeout: Option<Duration>,
 41}
 42
 43pub struct CopilotChatLanguageModelProvider {
 44    state: Model<State>,
 45}
 46
 47pub struct State {
 48    _copilot_chat_subscription: Option<Subscription>,
 49    _settings_subscription: Subscription,
 50}
 51
 52impl CopilotChatLanguageModelProvider {
 53    pub fn new(cx: &mut AppContext) -> Self {
 54        let state = cx.new_model(|cx| {
 55            let _copilot_chat_subscription = CopilotChat::global(cx)
 56                .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
 57            State {
 58                _copilot_chat_subscription,
 59                _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 60                    cx.notify();
 61                }),
 62            }
 63        });
 64
 65        Self { state }
 66    }
 67}
 68
 69impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
 70    type ObservableEntity = State;
 71
 72    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
 73        Some(self.state.clone())
 74    }
 75}
 76
 77impl LanguageModelProvider for CopilotChatLanguageModelProvider {
 78    fn id(&self) -> LanguageModelProviderId {
 79        LanguageModelProviderId(PROVIDER_ID.into())
 80    }
 81
 82    fn name(&self) -> LanguageModelProviderName {
 83        LanguageModelProviderName(PROVIDER_NAME.into())
 84    }
 85
 86    fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 87        CopilotChatModel::iter()
 88            .map(|model| {
 89                Arc::new(CopilotChatLanguageModel {
 90                    model,
 91                    request_limiter: RateLimiter::new(4),
 92                }) as Arc<dyn LanguageModel>
 93            })
 94            .collect()
 95    }
 96
 97    fn is_authenticated(&self, cx: &AppContext) -> bool {
 98        CopilotChat::global(cx)
 99            .map(|m| m.read(cx).is_authenticated())
100            .unwrap_or(false)
101    }
102
103    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
104        let result = if self.is_authenticated(cx) {
105            Ok(())
106        } else if let Some(copilot) = Copilot::global(cx) {
107            let error_msg = match copilot.read(cx).status() {
108                Status::Disabled => anyhow::anyhow!("Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."),
109                Status::Error(e) => anyhow::anyhow!(format!("Received the following error while signing into Copilot: {e}")),
110                Status::Starting { task: _ } => anyhow::anyhow!("Copilot is still starting, please wait for Copilot to start then try again"),
111                Status::Unauthorized => anyhow::anyhow!("Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."),
112                Status::Authorized => return Task::ready(Ok(())),
113                Status::SignedOut => anyhow::anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again."),
114                Status::SigningIn { prompt: _ } => anyhow::anyhow!("Still signing into Copilot..."),
115            };
116            Err(error_msg)
117        } else {
118            Err(anyhow::anyhow!(
119                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
120            ))
121        };
122        Task::ready(result)
123    }
124
125    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
126        cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
127    }
128
129    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
130        let Some(copilot) = Copilot::global(cx) else {
131            return Task::ready(Err(anyhow::anyhow!(
132                "Copilot is not available. Please ensure Copilot is enabled and running and try again."
133            )));
134        };
135
136        let state = self.state.clone();
137
138        cx.spawn(|mut cx| async move {
139            cx.update_model(&copilot, |model, cx| model.sign_out(cx))?
140                .await?;
141
142            cx.update_model(&state, |_, cx| {
143                cx.notify();
144            })?;
145
146            Ok(())
147        })
148    }
149}
150
151pub struct CopilotChatLanguageModel {
152    model: CopilotChatModel,
153    request_limiter: RateLimiter,
154}
155
156impl LanguageModel for CopilotChatLanguageModel {
157    fn id(&self) -> LanguageModelId {
158        LanguageModelId::from(self.model.id().to_string())
159    }
160
161    fn name(&self) -> LanguageModelName {
162        LanguageModelName::from(self.model.display_name().to_string())
163    }
164
165    fn provider_id(&self) -> LanguageModelProviderId {
166        LanguageModelProviderId(PROVIDER_ID.into())
167    }
168
169    fn provider_name(&self) -> LanguageModelProviderName {
170        LanguageModelProviderName(PROVIDER_NAME.into())
171    }
172
173    fn telemetry_id(&self) -> String {
174        format!("copilot_chat/{}", self.model.id())
175    }
176
177    fn max_token_count(&self) -> usize {
178        self.model.max_token_count()
179    }
180
181    fn count_tokens(
182        &self,
183        request: LanguageModelRequest,
184        cx: &AppContext,
185    ) -> BoxFuture<'static, Result<usize>> {
186        let model = match self.model {
187            CopilotChatModel::Gpt4 => open_ai::Model::Four,
188            CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
189        };
190
191        count_open_ai_tokens(request, model, cx)
192    }
193
194    fn stream_completion(
195        &self,
196        request: LanguageModelRequest,
197        cx: &AsyncAppContext,
198    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
199        if let Some(message) = request.messages.last() {
200            if message.content.trim().is_empty() {
201                const EMPTY_PROMPT_MSG: &str =
202                    "Empty prompts aren't allowed. Please provide a non-empty prompt.";
203                return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG))).boxed();
204            }
205
206            // Copilot Chat has a restriction that the final message must be from the user.
207            // While their API does return an error message for this, we can catch it earlier
208            // and provide a more helpful error message.
209            if !matches!(message.role, Role::User) {
210                const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
211                return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG))).boxed();
212            }
213        }
214
215        let request = self.to_copilot_chat_request(request);
216        let Ok(low_speed_timeout) = cx.update(|cx| {
217            AllLanguageModelSettings::get_global(cx)
218                .copilot_chat
219                .low_speed_timeout
220        }) else {
221            return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
222        };
223
224        let request_limiter = self.request_limiter.clone();
225        let future = cx.spawn(|cx| async move {
226            let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
227            request_limiter.stream(async move {
228                let response = response.await?;
229                let stream = response
230                    .filter_map(|response| async move {
231                        match response {
232                            Ok(result) => {
233                                let choice = result.choices.first();
234                                match choice {
235                                    Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
236                                    None => Some(Err(anyhow::anyhow!(
237                                        "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
238                                    ))),
239                                }
240                            }
241                            Err(err) => Some(Err(err)),
242                        }
243                    })
244                    .boxed();
245                Ok(stream)
246            }).await
247        });
248
249        async move { Ok(future.await?.boxed()) }.boxed()
250    }
251
252    fn use_any_tool(
253        &self,
254        _request: LanguageModelRequest,
255        _name: String,
256        _description: String,
257        _schema: serde_json::Value,
258        _cx: &AsyncAppContext,
259    ) -> BoxFuture<'static, Result<serde_json::Value>> {
260        future::ready(Err(anyhow!("not implemented"))).boxed()
261    }
262}
263
264impl CopilotChatLanguageModel {
265    pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest {
266        CopilotChatRequest::new(
267            self.model.clone(),
268            request
269                .messages
270                .into_iter()
271                .map(|msg| ChatMessage {
272                    role: match msg.role {
273                        Role::User => CopilotChatRole::User,
274                        Role::Assistant => CopilotChatRole::Assistant,
275                        Role::System => CopilotChatRole::System,
276                    },
277                    content: msg.content,
278                })
279                .collect(),
280        )
281    }
282}
283
284struct AuthenticationPrompt {
285    copilot_status: Option<copilot::Status>,
286    _subscription: Option<Subscription>,
287}
288
289impl AuthenticationPrompt {
290    pub fn new(cx: &mut ViewContext<Self>) -> Self {
291        let copilot = Copilot::global(cx);
292
293        Self {
294            copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
295            _subscription: copilot.as_ref().map(|copilot| {
296                cx.observe(copilot, |this, model, cx| {
297                    this.copilot_status = Some(model.read(cx).status());
298                    cx.notify();
299                })
300            }),
301        }
302    }
303}
304
305impl Render for AuthenticationPrompt {
306    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
307        let loading_icon = svg()
308            .size_8()
309            .path(IconName::ArrowCircle.path())
310            .text_color(cx.text_style().color)
311            .with_animation(
312                "icon_circle_arrow",
313                Animation::new(Duration::from_secs(2)).repeat(),
314                |svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))),
315            );
316
317        const ERROR_LABEL: &str = "Copilot Chat requires the Copilot plugin to be available and running. Please ensure Copilot is running and try again, or use a different Assistant provider.";
318        match &self.copilot_status {
319            Some(status) => match status {
320                Status::Disabled => {
321                    return v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL));
322                }
323                Status::Starting { task: _ } => {
324                    const LABEL: &str = "Starting Copilot...";
325                    return v_flex()
326                        .gap_6()
327                        .p_4()
328                        .justify_center()
329                        .items_center()
330                        .child(Label::new(LABEL))
331                        .child(loading_icon);
332                }
333                Status::SigningIn { prompt: _ } => {
334                    const LABEL: &str = "Signing in to Copilot...";
335                    return v_flex()
336                        .gap_6()
337                        .p_4()
338                        .justify_center()
339                        .items_center()
340                        .child(Label::new(LABEL))
341                        .child(loading_icon);
342                }
343                Status::Error(_) => {
344                    const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
345                    return v_flex()
346                        .gap_6()
347                        .p_4()
348                        .child(Label::new(LABEL))
349                        .child(svg().size_8().path(IconName::CopilotError.path()));
350                }
351                _ => {
352                    const LABEL: &str =
353                    "To use the assistant panel or inline assistant, you must login to GitHub Copilot. Your GitHub account must have an active Copilot Chat subscription.";
354                    v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
355                        v_flex()
356                            .gap_2()
357                            .child(
358                                Button::new("sign_in", "Sign In")
359                                    .icon_color(Color::Muted)
360                                    .icon(IconName::Github)
361                                    .icon_position(IconPosition::Start)
362                                    .icon_size(IconSize::Medium)
363                                    .style(ui::ButtonStyle::Filled)
364                                    .full_width()
365                                    .on_click(|_, cx| {
366                                        inline_completion_button::initiate_sign_in(cx)
367                                    }),
368                            )
369                            .child(
370                                div().flex().w_full().items_center().child(
371                                    Label::new("Sign in to start using Github Copilot Chat.")
372                                        .color(Color::Muted)
373                                        .size(ui::LabelSize::Small),
374                                ),
375                            ),
376                    )
377                }
378            },
379            None => v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL)),
380        }
381    }
382}