copilot_chat.rs

  1use std::future;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use copilot::copilot_chat::{
  6    ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
  7    Role as CopilotChatRole,
  8};
  9use copilot::{Copilot, Status};
 10use futures::future::BoxFuture;
 11use futures::stream::BoxStream;
 12use futures::{FutureExt, StreamExt};
 13use gpui::{
 14    percentage, svg, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription,
 15    Task, Transformation,
 16};
 17use language_model::{
 18    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
 19    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
 20    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 21};
 22use settings::SettingsStore;
 23use std::time::Duration;
 24use strum::IntoEnumIterator;
 25use ui::prelude::*;
 26
 27use super::anthropic::count_anthropic_tokens;
 28use super::open_ai::count_open_ai_tokens;
 29
/// Identifier string that this file wraps in `LanguageModelProviderId`
/// for both the provider and each of its models.
const PROVIDER_ID: &str = "copilot_chat";
/// Display name that this file wraps in `LanguageModelProviderName`
/// for both the provider and each of its models.
const PROVIDER_NAME: &str = "GitHub Copilot Chat";
 32
/// Settings type for the Copilot Chat provider.
///
/// The struct currently declares no fields.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatSettings {}
 35
/// Language model provider for GitHub Copilot Chat.
///
/// Owns the `State` entity that is handed out to observers and to the
/// configuration view.
pub struct CopilotChatLanguageModelProvider {
    /// Shared, observable provider state.
    state: Entity<State>,
}
 39
/// Observable state behind `CopilotChatLanguageModelProvider`.
///
/// Neither field is read in this file; they only hold on to subscriptions
/// (presumably dropping a `Subscription` cancels it — confirm against gpui).
pub struct State {
    /// Observation of the global `CopilotChat` entity; `None` when that
    /// global was not available at construction time.
    _copilot_chat_subscription: Option<Subscription>,
    /// Observation of the global `SettingsStore`.
    _settings_subscription: Subscription,
}
 44
 45impl State {
 46    fn is_authenticated(&self, cx: &App) -> bool {
 47        CopilotChat::global(cx)
 48            .map(|m| m.read(cx).is_authenticated())
 49            .unwrap_or(false)
 50    }
 51}
 52
 53impl CopilotChatLanguageModelProvider {
 54    pub fn new(cx: &mut App) -> Self {
 55        let state = cx.new(|cx| {
 56            let _copilot_chat_subscription = CopilotChat::global(cx)
 57                .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
 58            State {
 59                _copilot_chat_subscription,
 60                _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 61                    cx.notify();
 62                }),
 63            }
 64        });
 65
 66        Self { state }
 67    }
 68}
 69
 70impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
 71    type ObservableEntity = State;
 72
 73    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
 74        Some(self.state.clone())
 75    }
 76}
 77
 78impl LanguageModelProvider for CopilotChatLanguageModelProvider {
 79    fn id(&self) -> LanguageModelProviderId {
 80        LanguageModelProviderId(PROVIDER_ID.into())
 81    }
 82
 83    fn name(&self) -> LanguageModelProviderName {
 84        LanguageModelProviderName(PROVIDER_NAME.into())
 85    }
 86
 87    fn icon(&self) -> IconName {
 88        IconName::Copilot
 89    }
 90
 91    fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 92        CopilotChatModel::iter()
 93            .map(|model| {
 94                Arc::new(CopilotChatLanguageModel {
 95                    model,
 96                    request_limiter: RateLimiter::new(4),
 97                }) as Arc<dyn LanguageModel>
 98            })
 99            .collect()
100    }
101
102    fn is_authenticated(&self, cx: &App) -> bool {
103        self.state.read(cx).is_authenticated(cx)
104    }
105
106    fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
107        let result = if self.is_authenticated(cx) {
108            Ok(())
109        } else if let Some(copilot) = Copilot::global(cx) {
110            let error_msg = match copilot.read(cx).status() {
111                Status::Disabled => anyhow::anyhow!("Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."),
112                Status::Error(e) => anyhow::anyhow!(format!("Received the following error while signing into Copilot: {e}")),
113                Status::Starting { task: _ } => anyhow::anyhow!("Copilot is still starting, please wait for Copilot to start then try again"),
114                Status::Unauthorized => anyhow::anyhow!("Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."),
115                Status::Authorized => return Task::ready(Ok(())),
116                Status::SignedOut => anyhow::anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again."),
117                Status::SigningIn { prompt: _ } => anyhow::anyhow!("Still signing into Copilot..."),
118            };
119            Err(error_msg)
120        } else {
121            Err(anyhow::anyhow!(
122                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
123            ))
124        };
125        Task::ready(result)
126    }
127
128    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
129        let state = self.state.clone();
130        cx.new(|cx| ConfigurationView::new(state, cx)).into()
131    }
132
133    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
134        Task::ready(Err(anyhow!(
135            "Signing out of GitHub Copilot Chat is currently not supported."
136        )))
137    }
138}
139
/// A single Copilot Chat model exposed through the `LanguageModel` trait.
pub struct CopilotChatLanguageModel {
    /// The Copilot Chat model variant this instance represents.
    model: CopilotChatModel,
    /// Limiter that completion requests are routed through.
    request_limiter: RateLimiter,
}
144
145impl LanguageModel for CopilotChatLanguageModel {
146    fn id(&self) -> LanguageModelId {
147        LanguageModelId::from(self.model.id().to_string())
148    }
149
150    fn name(&self) -> LanguageModelName {
151        LanguageModelName::from(self.model.display_name().to_string())
152    }
153
154    fn provider_id(&self) -> LanguageModelProviderId {
155        LanguageModelProviderId(PROVIDER_ID.into())
156    }
157
158    fn provider_name(&self) -> LanguageModelProviderName {
159        LanguageModelProviderName(PROVIDER_NAME.into())
160    }
161
162    fn telemetry_id(&self) -> String {
163        format!("copilot_chat/{}", self.model.id())
164    }
165
166    fn max_token_count(&self) -> usize {
167        self.model.max_token_count()
168    }
169
170    fn count_tokens(
171        &self,
172        request: LanguageModelRequest,
173        cx: &App,
174    ) -> BoxFuture<'static, Result<usize>> {
175        match self.model {
176            CopilotChatModel::Claude3_5Sonnet => count_anthropic_tokens(request, cx),
177            _ => {
178                let model = match self.model {
179                    CopilotChatModel::Gpt4o => open_ai::Model::FourOmni,
180                    CopilotChatModel::Gpt4 => open_ai::Model::Four,
181                    CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
182                    CopilotChatModel::O1 | CopilotChatModel::O3Mini => open_ai::Model::Four,
183                    CopilotChatModel::Claude3_5Sonnet => unreachable!(),
184                };
185                count_open_ai_tokens(request, model, cx)
186            }
187        }
188    }
189
190    fn stream_completion(
191        &self,
192        request: LanguageModelRequest,
193        cx: &AsyncApp,
194    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
195        if let Some(message) = request.messages.last() {
196            if message.contents_empty() {
197                const EMPTY_PROMPT_MSG: &str =
198                    "Empty prompts aren't allowed. Please provide a non-empty prompt.";
199                return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG))).boxed();
200            }
201
202            // Copilot Chat has a restriction that the final message must be from the user.
203            // While their API does return an error message for this, we can catch it earlier
204            // and provide a more helpful error message.
205            if !matches!(message.role, Role::User) {
206                const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
207                return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG))).boxed();
208            }
209        }
210
211        let copilot_request = self.to_copilot_chat_request(request);
212        let is_streaming = copilot_request.stream;
213
214        let request_limiter = self.request_limiter.clone();
215        let future = cx.spawn(|cx| async move {
216            let response = CopilotChat::stream_completion(copilot_request, cx);
217            request_limiter.stream(async move {
218                let response = response.await?;
219                let stream = response
220                    .filter_map(move |response| async move {
221                        match response {
222                            Ok(result) => {
223                                let choice = result.choices.first();
224                                match choice {
225                                    Some(choice) if !is_streaming => {
226                                        match &choice.message {
227                                            Some(msg) => Some(Ok(msg.content.clone().unwrap_or_default())),
228                                            None => Some(Err(anyhow::anyhow!(
229                                                "The Copilot Chat API returned a response with no message content"
230                                            ))),
231                                        }
232                                    },
233                                    Some(choice) => {
234                                        match &choice.delta {
235                                            Some(delta) => Some(Ok(delta.content.clone().unwrap_or_default())),
236                                            None => Some(Err(anyhow::anyhow!(
237                                                "The Copilot Chat API returned a response with no delta content"
238                                            ))),
239                                        }
240                                    },
241                                    None => Some(Err(anyhow::anyhow!(
242                                        "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
243                                    ))),
244                                }
245                            }
246                            Err(err) => Some(Err(err)),
247                        }
248                    })
249                    .boxed();
250                Ok(stream)
251            }).await
252        });
253
254        async move {
255            Ok(future
256                .await?
257                .map(|result| result.map(LanguageModelCompletionEvent::Text))
258                .boxed())
259        }
260        .boxed()
261    }
262
263    fn use_any_tool(
264        &self,
265        _request: LanguageModelRequest,
266        _name: String,
267        _description: String,
268        _schema: serde_json::Value,
269        _cx: &AsyncApp,
270    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
271        future::ready(Err(anyhow!("not implemented"))).boxed()
272    }
273}
274
275impl CopilotChatLanguageModel {
276    pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest {
277        CopilotChatRequest::new(
278            self.model.clone(),
279            request
280                .messages
281                .into_iter()
282                .map(|msg| ChatMessage {
283                    role: match msg.role {
284                        Role::User => CopilotChatRole::User,
285                        Role::Assistant => CopilotChatRole::Assistant,
286                        Role::System => CopilotChatRole::System,
287                    },
288                    content: msg.string_contents(),
289                })
290                .collect(),
291        )
292    }
293}
294
/// View that shows the Copilot Chat authentication / configuration status.
struct ConfigurationView {
    /// Most recently observed Copilot status; `None` when no global
    /// `Copilot` entity was available when the view was created.
    copilot_status: Option<copilot::Status>,
    /// Provider state, consulted during rendering to check authentication.
    state: Entity<State>,
    /// Observation of the global `Copilot` entity that refreshes
    /// `copilot_status`; `None` when that entity was not available.
    _subscription: Option<Subscription>,
}
300
301impl ConfigurationView {
302    pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
303        let copilot = Copilot::global(cx);
304
305        Self {
306            copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
307            state,
308            _subscription: copilot.as_ref().map(|copilot| {
309                cx.observe(copilot, |this, model, cx| {
310                    this.copilot_status = Some(model.read(cx).status());
311                    cx.notify();
312                })
313            }),
314        }
315    }
316}
317
318impl Render for ConfigurationView {
319    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
320        if self.state.read(cx).is_authenticated(cx) {
321            const LABEL: &str = "Authorized.";
322            h_flex()
323                .gap_1()
324                .child(Icon::new(IconName::Check).color(Color::Success))
325                .child(Label::new(LABEL))
326        } else {
327            let loading_icon = svg()
328                .size_8()
329                .path(IconName::ArrowCircle.path())
330                .text_color(window.text_style().color)
331                .with_animation(
332                    "icon_circle_arrow",
333                    Animation::new(Duration::from_secs(2)).repeat(),
334                    |svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))),
335                );
336
337            const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
338
339            match &self.copilot_status {
340                Some(status) => match status {
341                    Status::Disabled => v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL)),
342                    Status::Starting { task: _ } => {
343                        const LABEL: &str = "Starting Copilot...";
344                        v_flex()
345                            .gap_6()
346                            .justify_center()
347                            .items_center()
348                            .child(Label::new(LABEL))
349                            .child(loading_icon)
350                    }
351                    Status::SigningIn { prompt: _ } => {
352                        const LABEL: &str = "Signing in to Copilot...";
353                        v_flex()
354                            .gap_6()
355                            .justify_center()
356                            .items_center()
357                            .child(Label::new(LABEL))
358                            .child(loading_icon)
359                    }
360                    Status::Error(_) => {
361                        const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
362                        v_flex()
363                            .gap_6()
364                            .child(Label::new(LABEL))
365                            .child(svg().size_8().path(IconName::CopilotError.path()))
366                    }
367                    _ => {
368                        const LABEL: &str =
369                    "To use Zed's assistant with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
370                        v_flex().gap_6().child(Label::new(LABEL)).child(
371                            v_flex()
372                                .gap_2()
373                                .child(
374                                    Button::new("sign_in", "Sign In")
375                                        .icon_color(Color::Muted)
376                                        .icon(IconName::Github)
377                                        .icon_position(IconPosition::Start)
378                                        .icon_size(IconSize::Medium)
379                                        .style(ui::ButtonStyle::Filled)
380                                        .full_width()
381                                        .on_click(|_, window, cx| {
382                                            copilot::initiate_sign_in(window, cx)
383                                        }),
384                                )
385                                .child(
386                                    div().flex().w_full().items_center().child(
387                                        Label::new("Sign in to start using Github Copilot Chat.")
388                                            .color(Color::Muted)
389                                            .size(ui::LabelSize::Small),
390                                    ),
391                                ),
392                        )
393                    }
394                },
395                None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),
396            }
397        }
398    }
399}