copilot_chat.rs

  1use std::pin::Pin;
  2use std::str::FromStr as _;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use collections::HashMap;
  7use copilot::copilot_chat::{
  8    ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
  9    Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
 10    ToolCall,
 11};
 12use copilot::{Copilot, Status};
 13use futures::future::BoxFuture;
 14use futures::stream::BoxStream;
 15use futures::{FutureExt, Stream, StreamExt};
 16use gpui::{
 17    Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
 18    Transformation, percentage, svg,
 19};
 20use language::language_settings::all_language_settings;
 21use language_model::{
 22    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
 23    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 24    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 25    LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
 26    LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
 27    StopReason, TokenUsage,
 28};
 29use settings::SettingsStore;
 30use std::time::Duration;
 31use ui::prelude::*;
 32use util::debug_panic;
 33
 34use super::anthropic::count_anthropic_tokens;
 35use super::google::count_google_tokens;
 36use super::open_ai::count_open_ai_tokens;
 37
 38const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
 39const PROVIDER_NAME: LanguageModelProviderName =
 40    LanguageModelProviderName::new("GitHub Copilot Chat");
 41
 42pub struct CopilotChatLanguageModelProvider {
 43    state: Entity<State>,
 44}
 45
 46pub struct State {
 47    _copilot_chat_subscription: Option<Subscription>,
 48    _settings_subscription: Subscription,
 49}
 50
 51impl State {
 52    fn is_authenticated(&self, cx: &App) -> bool {
 53        CopilotChat::global(cx)
 54            .map(|m| m.read(cx).is_authenticated())
 55            .unwrap_or(false)
 56    }
 57}
 58
 59impl CopilotChatLanguageModelProvider {
 60    pub fn new(cx: &mut App) -> Self {
 61        let state = cx.new(|cx| {
 62            let copilot_chat_subscription = CopilotChat::global(cx)
 63                .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
 64            State {
 65                _copilot_chat_subscription: copilot_chat_subscription,
 66                _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 67                    if let Some(copilot_chat) = CopilotChat::global(cx) {
 68                        let language_settings = all_language_settings(None, cx);
 69                        let configuration = copilot::copilot_chat::CopilotChatConfiguration {
 70                            enterprise_uri: language_settings
 71                                .edit_predictions
 72                                .copilot
 73                                .enterprise_uri
 74                                .clone(),
 75                        };
 76                        copilot_chat.update(cx, |chat, cx| {
 77                            chat.set_configuration(configuration, cx);
 78                        });
 79                    }
 80                    cx.notify();
 81                }),
 82            }
 83        });
 84
 85        Self { state }
 86    }
 87
 88    fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
 89        Arc::new(CopilotChatLanguageModel {
 90            model,
 91            request_limiter: RateLimiter::new(4),
 92        })
 93    }
 94}
 95
 96impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
 97    type ObservableEntity = State;
 98
 99    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
100        Some(self.state.clone())
101    }
102}
103
104impl LanguageModelProvider for CopilotChatLanguageModelProvider {
105    fn id(&self) -> LanguageModelProviderId {
106        PROVIDER_ID
107    }
108
109    fn name(&self) -> LanguageModelProviderName {
110        PROVIDER_NAME
111    }
112
113    fn icon(&self) -> IconName {
114        IconName::Copilot
115    }
116
117    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
118        let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
119        models
120            .first()
121            .map(|model| self.create_language_model(model.clone()))
122    }
123
124    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
125        // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
126        // model (e.g. 4o) and a sensible choice when considering premium requests
127        self.default_model(cx)
128    }
129
130    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
131        let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
132            return Vec::new();
133        };
134        models
135            .iter()
136            .map(|model| self.create_language_model(model.clone()))
137            .collect()
138    }
139
140    fn is_authenticated(&self, cx: &App) -> bool {
141        self.state.read(cx).is_authenticated(cx)
142    }
143
144    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
145        if self.is_authenticated(cx) {
146            return Task::ready(Ok(()));
147        };
148
149        let Some(copilot) = Copilot::global(cx) else {
150            return Task::ready( Err(anyhow!(
151                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
152            ).into()));
153        };
154
155        let err = match copilot.read(cx).status() {
156            Status::Authorized => return Task::ready(Ok(())),
157            Status::Disabled => anyhow!(
158                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
159            ),
160            Status::Error(err) => anyhow!(format!(
161                "Received the following error while signing into Copilot: {err}"
162            )),
163            Status::Starting { task: _ } => anyhow!(
164                "Copilot is still starting, please wait for Copilot to start then try again"
165            ),
166            Status::Unauthorized => anyhow!(
167                "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
168            ),
169            Status::SignedOut { .. } => {
170                anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
171            }
172            Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
173        };
174
175        Task::ready(Err(err.into()))
176    }
177
178    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
179        let state = self.state.clone();
180        cx.new(|cx| ConfigurationView::new(state, cx)).into()
181    }
182
183    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
184        Task::ready(Err(anyhow!(
185            "Signing out of GitHub Copilot Chat is currently not supported."
186        )))
187    }
188}
189
190pub struct CopilotChatLanguageModel {
191    model: CopilotChatModel,
192    request_limiter: RateLimiter,
193}
194
195impl LanguageModel for CopilotChatLanguageModel {
196    fn id(&self) -> LanguageModelId {
197        LanguageModelId::from(self.model.id().to_string())
198    }
199
200    fn name(&self) -> LanguageModelName {
201        LanguageModelName::from(self.model.display_name().to_string())
202    }
203
204    fn provider_id(&self) -> LanguageModelProviderId {
205        PROVIDER_ID
206    }
207
208    fn provider_name(&self) -> LanguageModelProviderName {
209        PROVIDER_NAME
210    }
211
212    fn supports_tools(&self) -> bool {
213        self.model.supports_tools()
214    }
215
216    fn supports_images(&self) -> bool {
217        self.model.supports_vision()
218    }
219
220    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
221        match self.model.vendor() {
222            ModelVendor::OpenAI | ModelVendor::Anthropic => {
223                LanguageModelToolSchemaFormat::JsonSchema
224            }
225            ModelVendor::Google => LanguageModelToolSchemaFormat::JsonSchemaSubset,
226        }
227    }
228
229    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
230        match choice {
231            LanguageModelToolChoice::Auto
232            | LanguageModelToolChoice::Any
233            | LanguageModelToolChoice::None => self.supports_tools(),
234        }
235    }
236
237    fn telemetry_id(&self) -> String {
238        format!("copilot_chat/{}", self.model.id())
239    }
240
241    fn max_token_count(&self) -> u64 {
242        self.model.max_token_count()
243    }
244
245    fn count_tokens(
246        &self,
247        request: LanguageModelRequest,
248        cx: &App,
249    ) -> BoxFuture<'static, Result<u64>> {
250        match self.model.vendor() {
251            ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
252            ModelVendor::Google => count_google_tokens(request, cx),
253            ModelVendor::OpenAI => {
254                let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
255                count_open_ai_tokens(request, model, cx)
256            }
257        }
258    }
259
260    fn stream_completion(
261        &self,
262        request: LanguageModelRequest,
263        cx: &AsyncApp,
264    ) -> BoxFuture<
265        'static,
266        Result<
267            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
268            LanguageModelCompletionError,
269        >,
270    > {
271        let copilot_request = match into_copilot_chat(&self.model, request) {
272            Ok(request) => request,
273            Err(err) => return futures::future::ready(Err(err.into())).boxed(),
274        };
275        let is_streaming = copilot_request.stream;
276
277        let request_limiter = self.request_limiter.clone();
278        let future = cx.spawn(async move |cx| {
279            let request = CopilotChat::stream_completion(copilot_request, cx.clone());
280            request_limiter
281                .stream(async move {
282                    let response = request.await?;
283                    Ok(map_to_language_model_completion_events(
284                        response,
285                        is_streaming,
286                    ))
287                })
288                .await
289        });
290        async move { Ok(future.await?.boxed()) }.boxed()
291    }
292}
293
294pub fn map_to_language_model_completion_events(
295    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
296    is_streaming: bool,
297) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
298    #[derive(Default)]
299    struct RawToolCall {
300        id: String,
301        name: String,
302        arguments: String,
303    }
304
305    struct State {
306        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
307        tool_calls_by_index: HashMap<usize, RawToolCall>,
308    }
309
310    futures::stream::unfold(
311        State {
312            events,
313            tool_calls_by_index: HashMap::default(),
314        },
315        move |mut state| async move {
316            if let Some(event) = state.events.next().await {
317                match event {
318                    Ok(event) => {
319                        let Some(choice) = event.choices.first() else {
320                            return Some((
321                                vec![Err(anyhow!("Response contained no choices").into())],
322                                state,
323                            ));
324                        };
325
326                        let delta = if is_streaming {
327                            choice.delta.as_ref()
328                        } else {
329                            choice.message.as_ref()
330                        };
331
332                        let Some(delta) = delta else {
333                            return Some((
334                                vec![Err(anyhow!("Response contained no delta").into())],
335                                state,
336                            ));
337                        };
338
339                        let mut events = Vec::new();
340                        if let Some(content) = delta.content.clone() {
341                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
342                        }
343
344                        for tool_call in &delta.tool_calls {
345                            let entry = state
346                                .tool_calls_by_index
347                                .entry(tool_call.index)
348                                .or_default();
349
350                            if let Some(tool_id) = tool_call.id.clone() {
351                                entry.id = tool_id;
352                            }
353
354                            if let Some(function) = tool_call.function.as_ref() {
355                                if let Some(name) = function.name.clone() {
356                                    entry.name = name;
357                                }
358
359                                if let Some(arguments) = function.arguments.clone() {
360                                    entry.arguments.push_str(&arguments);
361                                }
362                            }
363                        }
364
365                        if let Some(usage) = event.usage {
366                            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
367                                TokenUsage {
368                                    input_tokens: usage.prompt_tokens,
369                                    output_tokens: usage.completion_tokens,
370                                    cache_creation_input_tokens: 0,
371                                    cache_read_input_tokens: 0,
372                                },
373                            )));
374                        }
375
376                        match choice.finish_reason.as_deref() {
377                            Some("stop") => {
378                                events.push(Ok(LanguageModelCompletionEvent::Stop(
379                                    StopReason::EndTurn,
380                                )));
381                            }
382                            Some("tool_calls") => {
383                                events.extend(state.tool_calls_by_index.drain().map(
384                                    |(_, tool_call)| {
385                                        // The model can output an empty string
386                                        // to indicate the absence of arguments.
387                                        // When that happens, create an empty
388                                        // object instead.
389                                        let arguments = if tool_call.arguments.is_empty() {
390                                            Ok(serde_json::Value::Object(Default::default()))
391                                        } else {
392                                            serde_json::Value::from_str(&tool_call.arguments)
393                                        };
394                                        match arguments {
395                                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
396                                            LanguageModelToolUse {
397                                                id: tool_call.id.clone().into(),
398                                                name: tool_call.name.as_str().into(),
399                                                is_input_complete: true,
400                                                input,
401                                                raw_input: tool_call.arguments.clone(),
402                                            },
403                                        )),
404                                        Err(error) => Ok(
405                                            LanguageModelCompletionEvent::ToolUseJsonParseError {
406                                                id: tool_call.id.into(),
407                                                tool_name: tool_call.name.as_str().into(),
408                                                raw_input: tool_call.arguments.into(),
409                                                json_parse_error: error.to_string(),
410                                            },
411                                        ),
412                                    }
413                                    },
414                                ));
415
416                                events.push(Ok(LanguageModelCompletionEvent::Stop(
417                                    StopReason::ToolUse,
418                                )));
419                            }
420                            Some(stop_reason) => {
421                                log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
422                                events.push(Ok(LanguageModelCompletionEvent::Stop(
423                                    StopReason::EndTurn,
424                                )));
425                            }
426                            None => {}
427                        }
428
429                        return Some((events, state));
430                    }
431                    Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
432                }
433            }
434
435            None
436        },
437    )
438    .flat_map(futures::stream::iter)
439}
440
441fn into_copilot_chat(
442    model: &copilot::copilot_chat::Model,
443    request: LanguageModelRequest,
444) -> Result<CopilotChatRequest> {
445    let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
446    for message in request.messages {
447        if let Some(last_message) = request_messages.last_mut() {
448            if last_message.role == message.role {
449                last_message.content.extend(message.content);
450            } else {
451                request_messages.push(message);
452            }
453        } else {
454            request_messages.push(message);
455        }
456    }
457
458    let mut tool_called = false;
459    let mut messages: Vec<ChatMessage> = Vec::new();
460    for message in request_messages {
461        match message.role {
462            Role::User => {
463                for content in &message.content {
464                    if let MessageContent::ToolResult(tool_result) = content {
465                        let content = match &tool_result.content {
466                            LanguageModelToolResultContent::Text(text) => text.to_string().into(),
467                            LanguageModelToolResultContent::Image(image) => {
468                                if model.supports_vision() {
469                                    ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
470                                        image_url: ImageUrl {
471                                            url: image.to_base64_url(),
472                                        },
473                                    }])
474                                } else {
475                                    debug_panic!(
476                                        "This should be caught at {} level",
477                                        tool_result.tool_name
478                                    );
479                                    "[Tool responded with an image, but this model does not support vision]".to_string().into()
480                                }
481                            }
482                        };
483
484                        messages.push(ChatMessage::Tool {
485                            tool_call_id: tool_result.tool_use_id.to_string(),
486                            content,
487                        });
488                    }
489                }
490
491                let mut content_parts = Vec::new();
492                for content in &message.content {
493                    match content {
494                        MessageContent::Text(text) | MessageContent::Thinking { text, .. }
495                            if !text.is_empty() =>
496                        {
497                            if let Some(ChatMessagePart::Text { text: text_content }) =
498                                content_parts.last_mut()
499                            {
500                                text_content.push_str(text);
501                            } else {
502                                content_parts.push(ChatMessagePart::Text {
503                                    text: text.to_string(),
504                                });
505                            }
506                        }
507                        MessageContent::Image(image) if model.supports_vision() => {
508                            content_parts.push(ChatMessagePart::Image {
509                                image_url: ImageUrl {
510                                    url: image.to_base64_url(),
511                                },
512                            });
513                        }
514                        _ => {}
515                    }
516                }
517
518                if !content_parts.is_empty() {
519                    messages.push(ChatMessage::User {
520                        content: content_parts.into(),
521                    });
522                }
523            }
524            Role::Assistant => {
525                let mut tool_calls = Vec::new();
526                for content in &message.content {
527                    if let MessageContent::ToolUse(tool_use) = content {
528                        tool_called = true;
529                        tool_calls.push(ToolCall {
530                            id: tool_use.id.to_string(),
531                            content: copilot::copilot_chat::ToolCallContent::Function {
532                                function: copilot::copilot_chat::FunctionContent {
533                                    name: tool_use.name.to_string(),
534                                    arguments: serde_json::to_string(&tool_use.input)?,
535                                },
536                            },
537                        });
538                    }
539                }
540
541                let text_content = {
542                    let mut buffer = String::new();
543                    for string in message.content.iter().filter_map(|content| match content {
544                        MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
545                            Some(text.as_str())
546                        }
547                        MessageContent::ToolUse(_)
548                        | MessageContent::RedactedThinking(_)
549                        | MessageContent::ToolResult(_)
550                        | MessageContent::Image(_) => None,
551                    }) {
552                        buffer.push_str(string);
553                    }
554
555                    buffer
556                };
557
558                messages.push(ChatMessage::Assistant {
559                    content: if text_content.is_empty() {
560                        ChatMessageContent::empty()
561                    } else {
562                        text_content.into()
563                    },
564                    tool_calls,
565                });
566            }
567            Role::System => messages.push(ChatMessage::System {
568                content: message.string_contents(),
569            }),
570        }
571    }
572
573    let mut tools = request
574        .tools
575        .iter()
576        .map(|tool| Tool::Function {
577            function: copilot::copilot_chat::Function {
578                name: tool.name.clone(),
579                description: tool.description.clone(),
580                parameters: tool.input_schema.clone(),
581            },
582        })
583        .collect::<Vec<_>>();
584
585    // The API will return a Bad Request (with no error message) when tools
586    // were used previously in the conversation but no tools are provided as
587    // part of this request. Inserting a dummy tool seems to circumvent this
588    // error.
589    if tool_called && tools.is_empty() {
590        tools.push(Tool::Function {
591            function: copilot::copilot_chat::Function {
592                name: "noop".to_string(),
593                description: "No operation".to_string(),
594                parameters: serde_json::json!({
595                    "type": "object"
596                }),
597            },
598        });
599    }
600
601    Ok(CopilotChatRequest {
602        intent: true,
603        n: 1,
604        stream: model.uses_streaming(),
605        temperature: 0.1,
606        model: model.id().to_string(),
607        messages,
608        tools,
609        tool_choice: request.tool_choice.map(|choice| match choice {
610            LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
611            LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
612            LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
613        }),
614    })
615}
616
617struct ConfigurationView {
618    copilot_status: Option<copilot::Status>,
619    state: Entity<State>,
620    _subscription: Option<Subscription>,
621}
622
623impl ConfigurationView {
624    pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
625        let copilot = Copilot::global(cx);
626
627        Self {
628            copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
629            state,
630            _subscription: copilot.as_ref().map(|copilot| {
631                cx.observe(copilot, |this, model, cx| {
632                    this.copilot_status = Some(model.read(cx).status());
633                    cx.notify();
634                })
635            }),
636        }
637    }
638}
639
640impl Render for ConfigurationView {
641    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
642        if self.state.read(cx).is_authenticated(cx) {
643            h_flex()
644                .mt_1()
645                .p_1()
646                .justify_between()
647                .rounded_md()
648                .border_1()
649                .border_color(cx.theme().colors().border)
650                .bg(cx.theme().colors().background)
651                .child(
652                    h_flex()
653                        .gap_1()
654                        .child(Icon::new(IconName::Check).color(Color::Success))
655                        .child(Label::new("Authorized")),
656                )
657                .child(
658                    Button::new("sign_out", "Sign Out")
659                        .label_size(LabelSize::Small)
660                        .on_click(|_, window, cx| {
661                            window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
662                        }),
663                )
664        } else {
665            let loading_icon = Icon::new(IconName::ArrowCircle).with_animation(
666                "arrow-circle",
667                Animation::new(Duration::from_secs(4)).repeat(),
668                |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
669            );
670
671            const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
672
673            match &self.copilot_status {
674                Some(status) => match status {
675                    Status::Starting { task: _ } => h_flex()
676                        .gap_2()
677                        .child(loading_icon)
678                        .child(Label::new("Starting Copilot…")),
679                    Status::SigningIn { prompt: _ }
680                    | Status::SignedOut {
681                        awaiting_signing_in: true,
682                    } => h_flex()
683                        .gap_2()
684                        .child(loading_icon)
685                        .child(Label::new("Signing into Copilot…")),
686                    Status::Error(_) => {
687                        const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
688                        v_flex()
689                            .gap_6()
690                            .child(Label::new(LABEL))
691                            .child(svg().size_8().path(IconName::CopilotError.path()))
692                    }
693                    _ => {
694                        const LABEL: &str = "To use Zed's assistant with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
695                        v_flex().gap_2().child(Label::new(LABEL)).child(
696                            Button::new("sign_in", "Sign in to use GitHub Copilot")
697                                .icon_color(Color::Muted)
698                                .icon(IconName::Github)
699                                .icon_position(IconPosition::Start)
700                                .icon_size(IconSize::Medium)
701                                .full_width()
702                                .on_click(|_, window, cx| copilot::initiate_sign_in(window, cx)),
703                        )
704                    }
705                },
706                None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),
707            }
708        }
709    }
710}