copilot_chat.rs

  1use std::pin::Pin;
  2use std::str::FromStr as _;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use cloud_llm_client::CompletionIntent;
  7use collections::HashMap;
  8use copilot::copilot_chat::{
  9    ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
 10    Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
 11    ToolCall,
 12};
 13use copilot::{Copilot, Status};
 14use futures::future::BoxFuture;
 15use futures::stream::BoxStream;
 16use futures::{FutureExt, Stream, StreamExt};
 17use gpui::{
 18    Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
 19    Transformation, percentage, svg,
 20};
 21use language::language_settings::all_language_settings;
 22use language_model::{
 23    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
 24    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 25    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 26    LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
 27    LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
 28    StopReason, TokenUsage,
 29};
 30use settings::SettingsStore;
 31use std::time::Duration;
 32use ui::prelude::*;
 33use util::debug_panic;
 34
 35use super::anthropic::count_anthropic_tokens;
 36use super::google::count_google_tokens;
 37use super::open_ai::count_open_ai_tokens;
 38
 39const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
 40const PROVIDER_NAME: LanguageModelProviderName =
 41    LanguageModelProviderName::new("GitHub Copilot Chat");
 42
 43pub struct CopilotChatLanguageModelProvider {
 44    state: Entity<State>,
 45}
 46
 47pub struct State {
 48    _copilot_chat_subscription: Option<Subscription>,
 49    _settings_subscription: Subscription,
 50}
 51
 52impl State {
 53    fn is_authenticated(&self, cx: &App) -> bool {
 54        CopilotChat::global(cx)
 55            .map(|m| m.read(cx).is_authenticated())
 56            .unwrap_or(false)
 57    }
 58}
 59
 60impl CopilotChatLanguageModelProvider {
 61    pub fn new(cx: &mut App) -> Self {
 62        let state = cx.new(|cx| {
 63            let copilot_chat_subscription = CopilotChat::global(cx)
 64                .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
 65            State {
 66                _copilot_chat_subscription: copilot_chat_subscription,
 67                _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 68                    if let Some(copilot_chat) = CopilotChat::global(cx) {
 69                        let language_settings = all_language_settings(None, cx);
 70                        let configuration = copilot::copilot_chat::CopilotChatConfiguration {
 71                            enterprise_uri: language_settings
 72                                .edit_predictions
 73                                .copilot
 74                                .enterprise_uri
 75                                .clone(),
 76                        };
 77                        copilot_chat.update(cx, |chat, cx| {
 78                            chat.set_configuration(configuration, cx);
 79                        });
 80                    }
 81                    cx.notify();
 82                }),
 83            }
 84        });
 85
 86        Self { state }
 87    }
 88
 89    fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
 90        Arc::new(CopilotChatLanguageModel {
 91            model,
 92            request_limiter: RateLimiter::new(4),
 93        })
 94    }
 95}
 96
 97impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
 98    type ObservableEntity = State;
 99
100    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
101        Some(self.state.clone())
102    }
103}
104
105impl LanguageModelProvider for CopilotChatLanguageModelProvider {
106    fn id(&self) -> LanguageModelProviderId {
107        PROVIDER_ID
108    }
109
110    fn name(&self) -> LanguageModelProviderName {
111        PROVIDER_NAME
112    }
113
114    fn icon(&self) -> IconName {
115        IconName::Copilot
116    }
117
118    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
119        let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
120        models
121            .first()
122            .map(|model| self.create_language_model(model.clone()))
123    }
124
125    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
126        // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
127        // model (e.g. 4o) and a sensible choice when considering premium requests
128        self.default_model(cx)
129    }
130
131    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
132        let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
133            return Vec::new();
134        };
135        models
136            .iter()
137            .map(|model| self.create_language_model(model.clone()))
138            .collect()
139    }
140
141    fn is_authenticated(&self, cx: &App) -> bool {
142        self.state.read(cx).is_authenticated(cx)
143    }
144
145    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
146        if self.is_authenticated(cx) {
147            return Task::ready(Ok(()));
148        };
149
150        let Some(copilot) = Copilot::global(cx) else {
151            return Task::ready( Err(anyhow!(
152                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
153            ).into()));
154        };
155
156        let err = match copilot.read(cx).status() {
157            Status::Authorized => return Task::ready(Ok(())),
158            Status::Disabled => anyhow!(
159                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
160            ),
161            Status::Error(err) => anyhow!(format!(
162                "Received the following error while signing into Copilot: {err}"
163            )),
164            Status::Starting { task: _ } => anyhow!(
165                "Copilot is still starting, please wait for Copilot to start then try again"
166            ),
167            Status::Unauthorized => anyhow!(
168                "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
169            ),
170            Status::SignedOut { .. } => {
171                anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
172            }
173            Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
174        };
175
176        Task::ready(Err(err.into()))
177    }
178
179    fn configuration_view(
180        &self,
181        _target_agent: language_model::ConfigurationViewTargetAgent,
182        _: &mut Window,
183        cx: &mut App,
184    ) -> AnyView {
185        let state = self.state.clone();
186        cx.new(|cx| ConfigurationView::new(state, cx)).into()
187    }
188
189    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
190        Task::ready(Err(anyhow!(
191            "Signing out of GitHub Copilot Chat is currently not supported."
192        )))
193    }
194}
195
196pub struct CopilotChatLanguageModel {
197    model: CopilotChatModel,
198    request_limiter: RateLimiter,
199}
200
201impl LanguageModel for CopilotChatLanguageModel {
202    fn id(&self) -> LanguageModelId {
203        LanguageModelId::from(self.model.id().to_string())
204    }
205
206    fn name(&self) -> LanguageModelName {
207        LanguageModelName::from(self.model.display_name().to_string())
208    }
209
210    fn provider_id(&self) -> LanguageModelProviderId {
211        PROVIDER_ID
212    }
213
214    fn provider_name(&self) -> LanguageModelProviderName {
215        PROVIDER_NAME
216    }
217
218    fn supports_tools(&self) -> bool {
219        self.model.supports_tools()
220    }
221
222    fn supports_images(&self) -> bool {
223        self.model.supports_vision()
224    }
225
226    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
227        match self.model.vendor() {
228            ModelVendor::OpenAI | ModelVendor::Anthropic => {
229                LanguageModelToolSchemaFormat::JsonSchema
230            }
231            ModelVendor::Google => LanguageModelToolSchemaFormat::JsonSchemaSubset,
232        }
233    }
234
235    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
236        match choice {
237            LanguageModelToolChoice::Auto
238            | LanguageModelToolChoice::Any
239            | LanguageModelToolChoice::None => self.supports_tools(),
240        }
241    }
242
243    fn telemetry_id(&self) -> String {
244        format!("copilot_chat/{}", self.model.id())
245    }
246
247    fn max_token_count(&self) -> u64 {
248        self.model.max_token_count()
249    }
250
251    fn count_tokens(
252        &self,
253        request: LanguageModelRequest,
254        cx: &App,
255    ) -> BoxFuture<'static, Result<u64>> {
256        match self.model.vendor() {
257            ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
258            ModelVendor::Google => count_google_tokens(request, cx),
259            ModelVendor::OpenAI => {
260                let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
261                count_open_ai_tokens(request, model, cx)
262            }
263        }
264    }
265
266    fn stream_completion(
267        &self,
268        request: LanguageModelRequest,
269        cx: &AsyncApp,
270    ) -> BoxFuture<
271        'static,
272        Result<
273            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
274            LanguageModelCompletionError,
275        >,
276    > {
277        let is_user_initiated = request.intent.is_none_or(|intent| match intent {
278            CompletionIntent::UserPrompt
279            | CompletionIntent::ThreadContextSummarization
280            | CompletionIntent::InlineAssist
281            | CompletionIntent::TerminalInlineAssist
282            | CompletionIntent::GenerateGitCommitMessage => true,
283
284            CompletionIntent::ToolResults
285            | CompletionIntent::ThreadSummarization
286            | CompletionIntent::CreateFile
287            | CompletionIntent::EditFile => false,
288        });
289
290        let copilot_request = match into_copilot_chat(&self.model, request) {
291            Ok(request) => request,
292            Err(err) => return futures::future::ready(Err(err.into())).boxed(),
293        };
294        let is_streaming = copilot_request.stream;
295
296        let request_limiter = self.request_limiter.clone();
297        let future = cx.spawn(async move |cx| {
298            let request =
299                CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone());
300            request_limiter
301                .stream(async move {
302                    let response = request.await?;
303                    Ok(map_to_language_model_completion_events(
304                        response,
305                        is_streaming,
306                    ))
307                })
308                .await
309        });
310        async move { Ok(future.await?.boxed()) }.boxed()
311    }
312}
313
314pub fn map_to_language_model_completion_events(
315    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
316    is_streaming: bool,
317) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
318    #[derive(Default)]
319    struct RawToolCall {
320        id: String,
321        name: String,
322        arguments: String,
323    }
324
325    struct State {
326        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
327        tool_calls_by_index: HashMap<usize, RawToolCall>,
328    }
329
330    futures::stream::unfold(
331        State {
332            events,
333            tool_calls_by_index: HashMap::default(),
334        },
335        move |mut state| async move {
336            if let Some(event) = state.events.next().await {
337                match event {
338                    Ok(event) => {
339                        let Some(choice) = event.choices.first() else {
340                            return Some((
341                                vec![Err(anyhow!("Response contained no choices").into())],
342                                state,
343                            ));
344                        };
345
346                        let delta = if is_streaming {
347                            choice.delta.as_ref()
348                        } else {
349                            choice.message.as_ref()
350                        };
351
352                        let Some(delta) = delta else {
353                            return Some((
354                                vec![Err(anyhow!("Response contained no delta").into())],
355                                state,
356                            ));
357                        };
358
359                        let mut events = Vec::new();
360                        if let Some(content) = delta.content.clone() {
361                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
362                        }
363
364                        for tool_call in &delta.tool_calls {
365                            let entry = state
366                                .tool_calls_by_index
367                                .entry(tool_call.index)
368                                .or_default();
369
370                            if let Some(tool_id) = tool_call.id.clone() {
371                                entry.id = tool_id;
372                            }
373
374                            if let Some(function) = tool_call.function.as_ref() {
375                                if let Some(name) = function.name.clone() {
376                                    entry.name = name;
377                                }
378
379                                if let Some(arguments) = function.arguments.clone() {
380                                    entry.arguments.push_str(&arguments);
381                                }
382                            }
383                        }
384
385                        if let Some(usage) = event.usage {
386                            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
387                                TokenUsage {
388                                    input_tokens: usage.prompt_tokens,
389                                    output_tokens: usage.completion_tokens,
390                                    cache_creation_input_tokens: 0,
391                                    cache_read_input_tokens: 0,
392                                },
393                            )));
394                        }
395
396                        match choice.finish_reason.as_deref() {
397                            Some("stop") => {
398                                events.push(Ok(LanguageModelCompletionEvent::Stop(
399                                    StopReason::EndTurn,
400                                )));
401                            }
402                            Some("tool_calls") => {
403                                events.extend(state.tool_calls_by_index.drain().map(
404                                    |(_, tool_call)| {
405                                        // The model can output an empty string
406                                        // to indicate the absence of arguments.
407                                        // When that happens, create an empty
408                                        // object instead.
409                                        let arguments = if tool_call.arguments.is_empty() {
410                                            Ok(serde_json::Value::Object(Default::default()))
411                                        } else {
412                                            serde_json::Value::from_str(&tool_call.arguments)
413                                        };
414                                        match arguments {
415                                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
416                                            LanguageModelToolUse {
417                                                id: tool_call.id.clone().into(),
418                                                name: tool_call.name.as_str().into(),
419                                                is_input_complete: true,
420                                                input,
421                                                raw_input: tool_call.arguments.clone(),
422                                            },
423                                        )),
424                                        Err(error) => Ok(
425                                            LanguageModelCompletionEvent::ToolUseJsonParseError {
426                                                id: tool_call.id.into(),
427                                                tool_name: tool_call.name.as_str().into(),
428                                                raw_input: tool_call.arguments.into(),
429                                                json_parse_error: error.to_string(),
430                                            },
431                                        ),
432                                    }
433                                    },
434                                ));
435
436                                events.push(Ok(LanguageModelCompletionEvent::Stop(
437                                    StopReason::ToolUse,
438                                )));
439                            }
440                            Some(stop_reason) => {
441                                log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
442                                events.push(Ok(LanguageModelCompletionEvent::Stop(
443                                    StopReason::EndTurn,
444                                )));
445                            }
446                            None => {}
447                        }
448
449                        return Some((events, state));
450                    }
451                    Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
452                }
453            }
454
455            None
456        },
457    )
458    .flat_map(futures::stream::iter)
459}
460
461fn into_copilot_chat(
462    model: &copilot::copilot_chat::Model,
463    request: LanguageModelRequest,
464) -> Result<CopilotChatRequest> {
465    let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
466    for message in request.messages {
467        if let Some(last_message) = request_messages.last_mut() {
468            if last_message.role == message.role {
469                last_message.content.extend(message.content);
470            } else {
471                request_messages.push(message);
472            }
473        } else {
474            request_messages.push(message);
475        }
476    }
477
478    let mut messages: Vec<ChatMessage> = Vec::new();
479    for message in request_messages {
480        match message.role {
481            Role::User => {
482                for content in &message.content {
483                    if let MessageContent::ToolResult(tool_result) = content {
484                        let content = match &tool_result.content {
485                            LanguageModelToolResultContent::Text(text) => text.to_string().into(),
486                            LanguageModelToolResultContent::Image(image) => {
487                                if model.supports_vision() {
488                                    ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
489                                        image_url: ImageUrl {
490                                            url: image.to_base64_url(),
491                                        },
492                                    }])
493                                } else {
494                                    debug_panic!(
495                                        "This should be caught at {} level",
496                                        tool_result.tool_name
497                                    );
498                                    "[Tool responded with an image, but this model does not support vision]".to_string().into()
499                                }
500                            }
501                        };
502
503                        messages.push(ChatMessage::Tool {
504                            tool_call_id: tool_result.tool_use_id.to_string(),
505                            content,
506                        });
507                    }
508                }
509
510                let mut content_parts = Vec::new();
511                for content in &message.content {
512                    match content {
513                        MessageContent::Text(text) | MessageContent::Thinking { text, .. }
514                            if !text.is_empty() =>
515                        {
516                            if let Some(ChatMessagePart::Text { text: text_content }) =
517                                content_parts.last_mut()
518                            {
519                                text_content.push_str(text);
520                            } else {
521                                content_parts.push(ChatMessagePart::Text {
522                                    text: text.to_string(),
523                                });
524                            }
525                        }
526                        MessageContent::Image(image) if model.supports_vision() => {
527                            content_parts.push(ChatMessagePart::Image {
528                                image_url: ImageUrl {
529                                    url: image.to_base64_url(),
530                                },
531                            });
532                        }
533                        _ => {}
534                    }
535                }
536
537                if !content_parts.is_empty() {
538                    messages.push(ChatMessage::User {
539                        content: content_parts.into(),
540                    });
541                }
542            }
543            Role::Assistant => {
544                let mut tool_calls = Vec::new();
545                for content in &message.content {
546                    if let MessageContent::ToolUse(tool_use) = content {
547                        tool_calls.push(ToolCall {
548                            id: tool_use.id.to_string(),
549                            content: copilot::copilot_chat::ToolCallContent::Function {
550                                function: copilot::copilot_chat::FunctionContent {
551                                    name: tool_use.name.to_string(),
552                                    arguments: serde_json::to_string(&tool_use.input)?,
553                                },
554                            },
555                        });
556                    }
557                }
558
559                let text_content = {
560                    let mut buffer = String::new();
561                    for string in message.content.iter().filter_map(|content| match content {
562                        MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
563                            Some(text.as_str())
564                        }
565                        MessageContent::ToolUse(_)
566                        | MessageContent::RedactedThinking(_)
567                        | MessageContent::ToolResult(_)
568                        | MessageContent::Image(_) => None,
569                    }) {
570                        buffer.push_str(string);
571                    }
572
573                    buffer
574                };
575
576                messages.push(ChatMessage::Assistant {
577                    content: if text_content.is_empty() {
578                        ChatMessageContent::empty()
579                    } else {
580                        text_content.into()
581                    },
582                    tool_calls,
583                });
584            }
585            Role::System => messages.push(ChatMessage::System {
586                content: message.string_contents(),
587            }),
588        }
589    }
590
591    let tools = request
592        .tools
593        .iter()
594        .map(|tool| Tool::Function {
595            function: copilot::copilot_chat::Function {
596                name: tool.name.clone(),
597                description: tool.description.clone(),
598                parameters: tool.input_schema.clone(),
599            },
600        })
601        .collect::<Vec<_>>();
602
603    Ok(CopilotChatRequest {
604        intent: true,
605        n: 1,
606        stream: model.uses_streaming(),
607        temperature: 0.1,
608        model: model.id().to_string(),
609        messages,
610        tools,
611        tool_choice: request.tool_choice.map(|choice| match choice {
612            LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
613            LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
614            LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
615        }),
616    })
617}
618
619struct ConfigurationView {
620    copilot_status: Option<copilot::Status>,
621    state: Entity<State>,
622    _subscription: Option<Subscription>,
623}
624
625impl ConfigurationView {
626    pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
627        let copilot = Copilot::global(cx);
628
629        Self {
630            copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
631            state,
632            _subscription: copilot.as_ref().map(|copilot| {
633                cx.observe(copilot, |this, model, cx| {
634                    this.copilot_status = Some(model.read(cx).status());
635                    cx.notify();
636                })
637            }),
638        }
639    }
640}
641
642impl Render for ConfigurationView {
643    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
644        if self.state.read(cx).is_authenticated(cx) {
645            h_flex()
646                .mt_1()
647                .p_1()
648                .justify_between()
649                .rounded_md()
650                .border_1()
651                .border_color(cx.theme().colors().border)
652                .bg(cx.theme().colors().background)
653                .child(
654                    h_flex()
655                        .gap_1()
656                        .child(Icon::new(IconName::Check).color(Color::Success))
657                        .child(Label::new("Authorized")),
658                )
659                .child(
660                    Button::new("sign_out", "Sign Out")
661                        .label_size(LabelSize::Small)
662                        .on_click(|_, window, cx| {
663                            window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
664                        }),
665                )
666        } else {
667            let loading_icon = Icon::new(IconName::ArrowCircle).with_animation(
668                "arrow-circle",
669                Animation::new(Duration::from_secs(4)).repeat(),
670                |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
671            );
672
673            const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
674
675            match &self.copilot_status {
676                Some(status) => match status {
677                    Status::Starting { task: _ } => h_flex()
678                        .gap_2()
679                        .child(loading_icon)
680                        .child(Label::new("Starting Copilot…")),
681                    Status::SigningIn { prompt: _ }
682                    | Status::SignedOut {
683                        awaiting_signing_in: true,
684                    } => h_flex()
685                        .gap_2()
686                        .child(loading_icon)
687                        .child(Label::new("Signing into Copilot…")),
688                    Status::Error(_) => {
689                        const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
690                        v_flex()
691                            .gap_6()
692                            .child(Label::new(LABEL))
693                            .child(svg().size_8().path(IconName::CopilotError.path()))
694                    }
695                    _ => {
696                        const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
697
698                        v_flex().gap_2().child(Label::new(LABEL)).child(
699                            Button::new("sign_in", "Sign in to use GitHub Copilot")
700                                .icon_color(Color::Muted)
701                                .icon(IconName::Github)
702                                .icon_position(IconPosition::Start)
703                                .icon_size(IconSize::Medium)
704                                .full_width()
705                                .on_click(|_, window, cx| copilot::initiate_sign_in(window, cx)),
706                        )
707                    }
708                },
709                None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),
710            }
711        }
712    }
713}