copilot_chat.rs

  1use std::pin::Pin;
  2use std::str::FromStr as _;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use cloud_llm_client::CompletionIntent;
  7use collections::HashMap;
  8use copilot::copilot_chat::{
  9    ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
 10    Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
 11    ToolCall,
 12};
 13use copilot::{Copilot, Status};
 14use futures::future::BoxFuture;
 15use futures::stream::BoxStream;
 16use futures::{FutureExt, Stream, StreamExt};
 17use gpui::{
 18    Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
 19    Transformation, percentage, svg,
 20};
 21use language::language_settings::all_language_settings;
 22use language_model::{
 23    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
 24    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 25    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 26    LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
 27    LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
 28    StopReason, TokenUsage,
 29};
 30use settings::SettingsStore;
 31use std::time::Duration;
 32use ui::prelude::*;
 33use util::debug_panic;
 34
 35use crate::provider::x_ai::count_xai_tokens;
 36
 37use super::anthropic::count_anthropic_tokens;
 38use super::google::count_google_tokens;
 39use super::open_ai::count_open_ai_tokens;
 40
 41const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
 42const PROVIDER_NAME: LanguageModelProviderName =
 43    LanguageModelProviderName::new("GitHub Copilot Chat");
 44
 45pub struct CopilotChatLanguageModelProvider {
 46    state: Entity<State>,
 47}
 48
 49pub struct State {
 50    _copilot_chat_subscription: Option<Subscription>,
 51    _settings_subscription: Subscription,
 52}
 53
 54impl State {
 55    fn is_authenticated(&self, cx: &App) -> bool {
 56        CopilotChat::global(cx)
 57            .map(|m| m.read(cx).is_authenticated())
 58            .unwrap_or(false)
 59    }
 60}
 61
 62impl CopilotChatLanguageModelProvider {
 63    pub fn new(cx: &mut App) -> Self {
 64        let state = cx.new(|cx| {
 65            let copilot_chat_subscription = CopilotChat::global(cx)
 66                .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
 67            State {
 68                _copilot_chat_subscription: copilot_chat_subscription,
 69                _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 70                    if let Some(copilot_chat) = CopilotChat::global(cx) {
 71                        let language_settings = all_language_settings(None, cx);
 72                        let configuration = copilot::copilot_chat::CopilotChatConfiguration {
 73                            enterprise_uri: language_settings
 74                                .edit_predictions
 75                                .copilot
 76                                .enterprise_uri
 77                                .clone(),
 78                        };
 79                        copilot_chat.update(cx, |chat, cx| {
 80                            chat.set_configuration(configuration, cx);
 81                        });
 82                    }
 83                    cx.notify();
 84                }),
 85            }
 86        });
 87
 88        Self { state }
 89    }
 90
 91    fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
 92        Arc::new(CopilotChatLanguageModel {
 93            model,
 94            request_limiter: RateLimiter::new(4),
 95        })
 96    }
 97}
 98
 99impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
100    type ObservableEntity = State;
101
102    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
103        Some(self.state.clone())
104    }
105}
106
107impl LanguageModelProvider for CopilotChatLanguageModelProvider {
108    fn id(&self) -> LanguageModelProviderId {
109        PROVIDER_ID
110    }
111
112    fn name(&self) -> LanguageModelProviderName {
113        PROVIDER_NAME
114    }
115
116    fn icon(&self) -> IconName {
117        IconName::Copilot
118    }
119
120    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
121        let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
122        models
123            .first()
124            .map(|model| self.create_language_model(model.clone()))
125    }
126
127    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
128        // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
129        // model (e.g. 4o) and a sensible choice when considering premium requests
130        self.default_model(cx)
131    }
132
133    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
134        let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
135            return Vec::new();
136        };
137        models
138            .iter()
139            .map(|model| self.create_language_model(model.clone()))
140            .collect()
141    }
142
143    fn is_authenticated(&self, cx: &App) -> bool {
144        self.state.read(cx).is_authenticated(cx)
145    }
146
147    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
148        if self.is_authenticated(cx) {
149            return Task::ready(Ok(()));
150        };
151
152        let Some(copilot) = Copilot::global(cx) else {
153            return Task::ready( Err(anyhow!(
154                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
155            ).into()));
156        };
157
158        let err = match copilot.read(cx).status() {
159            Status::Authorized => return Task::ready(Ok(())),
160            Status::Disabled => anyhow!(
161                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
162            ),
163            Status::Error(err) => anyhow!(format!(
164                "Received the following error while signing into Copilot: {err}"
165            )),
166            Status::Starting { task: _ } => anyhow!(
167                "Copilot is still starting, please wait for Copilot to start then try again"
168            ),
169            Status::Unauthorized => anyhow!(
170                "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
171            ),
172            Status::SignedOut { .. } => {
173                anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
174            }
175            Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
176        };
177
178        Task::ready(Err(err.into()))
179    }
180
181    fn configuration_view(
182        &self,
183        _target_agent: language_model::ConfigurationViewTargetAgent,
184        _: &mut Window,
185        cx: &mut App,
186    ) -> AnyView {
187        let state = self.state.clone();
188        cx.new(|cx| ConfigurationView::new(state, cx)).into()
189    }
190
191    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
192        Task::ready(Err(anyhow!(
193            "Signing out of GitHub Copilot Chat is currently not supported."
194        )))
195    }
196}
197
198pub struct CopilotChatLanguageModel {
199    model: CopilotChatModel,
200    request_limiter: RateLimiter,
201}
202
203impl LanguageModel for CopilotChatLanguageModel {
204    fn id(&self) -> LanguageModelId {
205        LanguageModelId::from(self.model.id().to_string())
206    }
207
208    fn name(&self) -> LanguageModelName {
209        LanguageModelName::from(self.model.display_name().to_string())
210    }
211
212    fn provider_id(&self) -> LanguageModelProviderId {
213        PROVIDER_ID
214    }
215
216    fn provider_name(&self) -> LanguageModelProviderName {
217        PROVIDER_NAME
218    }
219
220    fn supports_tools(&self) -> bool {
221        self.model.supports_tools()
222    }
223
224    fn supports_images(&self) -> bool {
225        self.model.supports_vision()
226    }
227
228    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
229        match self.model.vendor() {
230            ModelVendor::OpenAI | ModelVendor::Anthropic => {
231                LanguageModelToolSchemaFormat::JsonSchema
232            }
233            ModelVendor::Google | ModelVendor::XAI => {
234                LanguageModelToolSchemaFormat::JsonSchemaSubset
235            }
236        }
237    }
238
239    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
240        match choice {
241            LanguageModelToolChoice::Auto
242            | LanguageModelToolChoice::Any
243            | LanguageModelToolChoice::None => self.supports_tools(),
244        }
245    }
246
247    fn telemetry_id(&self) -> String {
248        format!("copilot_chat/{}", self.model.id())
249    }
250
251    fn max_token_count(&self) -> u64 {
252        self.model.max_token_count()
253    }
254
255    fn count_tokens(
256        &self,
257        request: LanguageModelRequest,
258        cx: &App,
259    ) -> BoxFuture<'static, Result<u64>> {
260        match self.model.vendor() {
261            ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
262            ModelVendor::Google => count_google_tokens(request, cx),
263            ModelVendor::XAI => {
264                let model = x_ai::Model::from_id(self.model.id()).unwrap_or_default();
265                count_xai_tokens(request, model, cx)
266            }
267            ModelVendor::OpenAI => {
268                let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
269                count_open_ai_tokens(request, model, cx)
270            }
271        }
272    }
273
274    fn stream_completion(
275        &self,
276        request: LanguageModelRequest,
277        cx: &AsyncApp,
278    ) -> BoxFuture<
279        'static,
280        Result<
281            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
282            LanguageModelCompletionError,
283        >,
284    > {
285        let is_user_initiated = request.intent.is_none_or(|intent| match intent {
286            CompletionIntent::UserPrompt
287            | CompletionIntent::ThreadContextSummarization
288            | CompletionIntent::InlineAssist
289            | CompletionIntent::TerminalInlineAssist
290            | CompletionIntent::GenerateGitCommitMessage => true,
291
292            CompletionIntent::ToolResults
293            | CompletionIntent::ThreadSummarization
294            | CompletionIntent::CreateFile
295            | CompletionIntent::EditFile => false,
296        });
297
298        let copilot_request = match into_copilot_chat(&self.model, request) {
299            Ok(request) => request,
300            Err(err) => return futures::future::ready(Err(err.into())).boxed(),
301        };
302        let is_streaming = copilot_request.stream;
303
304        let request_limiter = self.request_limiter.clone();
305        let future = cx.spawn(async move |cx| {
306            let request =
307                CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone());
308            request_limiter
309                .stream(async move {
310                    let response = request.await?;
311                    Ok(map_to_language_model_completion_events(
312                        response,
313                        is_streaming,
314                    ))
315                })
316                .await
317        });
318        async move { Ok(future.await?.boxed()) }.boxed()
319    }
320}
321
322pub fn map_to_language_model_completion_events(
323    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
324    is_streaming: bool,
325) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
326    #[derive(Default)]
327    struct RawToolCall {
328        id: String,
329        name: String,
330        arguments: String,
331    }
332
333    struct State {
334        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
335        tool_calls_by_index: HashMap<usize, RawToolCall>,
336    }
337
338    futures::stream::unfold(
339        State {
340            events,
341            tool_calls_by_index: HashMap::default(),
342        },
343        move |mut state| async move {
344            if let Some(event) = state.events.next().await {
345                match event {
346                    Ok(event) => {
347                        let Some(choice) = event.choices.first() else {
348                            return Some((
349                                vec![Err(anyhow!("Response contained no choices").into())],
350                                state,
351                            ));
352                        };
353
354                        let delta = if is_streaming {
355                            choice.delta.as_ref()
356                        } else {
357                            choice.message.as_ref()
358                        };
359
360                        let Some(delta) = delta else {
361                            return Some((
362                                vec![Err(anyhow!("Response contained no delta").into())],
363                                state,
364                            ));
365                        };
366
367                        let mut events = Vec::new();
368                        if let Some(content) = delta.content.clone() {
369                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
370                        }
371
372                        for tool_call in &delta.tool_calls {
373                            let entry = state
374                                .tool_calls_by_index
375                                .entry(tool_call.index)
376                                .or_default();
377
378                            if let Some(tool_id) = tool_call.id.clone() {
379                                entry.id = tool_id;
380                            }
381
382                            if let Some(function) = tool_call.function.as_ref() {
383                                if let Some(name) = function.name.clone() {
384                                    entry.name = name;
385                                }
386
387                                if let Some(arguments) = function.arguments.clone() {
388                                    entry.arguments.push_str(&arguments);
389                                }
390                            }
391                        }
392
393                        if let Some(usage) = event.usage {
394                            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
395                                TokenUsage {
396                                    input_tokens: usage.prompt_tokens,
397                                    output_tokens: usage.completion_tokens,
398                                    cache_creation_input_tokens: 0,
399                                    cache_read_input_tokens: 0,
400                                },
401                            )));
402                        }
403
404                        match choice.finish_reason.as_deref() {
405                            Some("stop") => {
406                                events.push(Ok(LanguageModelCompletionEvent::Stop(
407                                    StopReason::EndTurn,
408                                )));
409                            }
410                            Some("tool_calls") => {
411                                events.extend(state.tool_calls_by_index.drain().map(
412                                    |(_, tool_call)| {
413                                        // The model can output an empty string
414                                        // to indicate the absence of arguments.
415                                        // When that happens, create an empty
416                                        // object instead.
417                                        let arguments = if tool_call.arguments.is_empty() {
418                                            Ok(serde_json::Value::Object(Default::default()))
419                                        } else {
420                                            serde_json::Value::from_str(&tool_call.arguments)
421                                        };
422                                        match arguments {
423                                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
424                                            LanguageModelToolUse {
425                                                id: tool_call.id.clone().into(),
426                                                name: tool_call.name.as_str().into(),
427                                                is_input_complete: true,
428                                                input,
429                                                raw_input: tool_call.arguments.clone(),
430                                            },
431                                        )),
432                                        Err(error) => Ok(
433                                            LanguageModelCompletionEvent::ToolUseJsonParseError {
434                                                id: tool_call.id.into(),
435                                                tool_name: tool_call.name.as_str().into(),
436                                                raw_input: tool_call.arguments.into(),
437                                                json_parse_error: error.to_string(),
438                                            },
439                                        ),
440                                    }
441                                    },
442                                ));
443
444                                events.push(Ok(LanguageModelCompletionEvent::Stop(
445                                    StopReason::ToolUse,
446                                )));
447                            }
448                            Some(stop_reason) => {
449                                log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
450                                events.push(Ok(LanguageModelCompletionEvent::Stop(
451                                    StopReason::EndTurn,
452                                )));
453                            }
454                            None => {}
455                        }
456
457                        return Some((events, state));
458                    }
459                    Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
460                }
461            }
462
463            None
464        },
465    )
466    .flat_map(futures::stream::iter)
467}
468
469fn into_copilot_chat(
470    model: &copilot::copilot_chat::Model,
471    request: LanguageModelRequest,
472) -> Result<CopilotChatRequest> {
473    let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
474    for message in request.messages {
475        if let Some(last_message) = request_messages.last_mut() {
476            if last_message.role == message.role {
477                last_message.content.extend(message.content);
478            } else {
479                request_messages.push(message);
480            }
481        } else {
482            request_messages.push(message);
483        }
484    }
485
486    let mut messages: Vec<ChatMessage> = Vec::new();
487    for message in request_messages {
488        match message.role {
489            Role::User => {
490                for content in &message.content {
491                    if let MessageContent::ToolResult(tool_result) = content {
492                        let content = match &tool_result.content {
493                            LanguageModelToolResultContent::Text(text) => text.to_string().into(),
494                            LanguageModelToolResultContent::Image(image) => {
495                                if model.supports_vision() {
496                                    ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
497                                        image_url: ImageUrl {
498                                            url: image.to_base64_url(),
499                                        },
500                                    }])
501                                } else {
502                                    debug_panic!(
503                                        "This should be caught at {} level",
504                                        tool_result.tool_name
505                                    );
506                                    "[Tool responded with an image, but this model does not support vision]".to_string().into()
507                                }
508                            }
509                        };
510
511                        messages.push(ChatMessage::Tool {
512                            tool_call_id: tool_result.tool_use_id.to_string(),
513                            content,
514                        });
515                    }
516                }
517
518                let mut content_parts = Vec::new();
519                for content in &message.content {
520                    match content {
521                        MessageContent::Text(text) | MessageContent::Thinking { text, .. }
522                            if !text.is_empty() =>
523                        {
524                            if let Some(ChatMessagePart::Text { text: text_content }) =
525                                content_parts.last_mut()
526                            {
527                                text_content.push_str(text);
528                            } else {
529                                content_parts.push(ChatMessagePart::Text {
530                                    text: text.to_string(),
531                                });
532                            }
533                        }
534                        MessageContent::Image(image) if model.supports_vision() => {
535                            content_parts.push(ChatMessagePart::Image {
536                                image_url: ImageUrl {
537                                    url: image.to_base64_url(),
538                                },
539                            });
540                        }
541                        _ => {}
542                    }
543                }
544
545                if !content_parts.is_empty() {
546                    messages.push(ChatMessage::User {
547                        content: content_parts.into(),
548                    });
549                }
550            }
551            Role::Assistant => {
552                let mut tool_calls = Vec::new();
553                for content in &message.content {
554                    if let MessageContent::ToolUse(tool_use) = content {
555                        tool_calls.push(ToolCall {
556                            id: tool_use.id.to_string(),
557                            content: copilot::copilot_chat::ToolCallContent::Function {
558                                function: copilot::copilot_chat::FunctionContent {
559                                    name: tool_use.name.to_string(),
560                                    arguments: serde_json::to_string(&tool_use.input)?,
561                                },
562                            },
563                        });
564                    }
565                }
566
567                let text_content = {
568                    let mut buffer = String::new();
569                    for string in message.content.iter().filter_map(|content| match content {
570                        MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
571                            Some(text.as_str())
572                        }
573                        MessageContent::ToolUse(_)
574                        | MessageContent::RedactedThinking(_)
575                        | MessageContent::ToolResult(_)
576                        | MessageContent::Image(_) => None,
577                    }) {
578                        buffer.push_str(string);
579                    }
580
581                    buffer
582                };
583
584                messages.push(ChatMessage::Assistant {
585                    content: if text_content.is_empty() {
586                        ChatMessageContent::empty()
587                    } else {
588                        text_content.into()
589                    },
590                    tool_calls,
591                });
592            }
593            Role::System => messages.push(ChatMessage::System {
594                content: message.string_contents(),
595            }),
596        }
597    }
598
599    let tools = request
600        .tools
601        .iter()
602        .map(|tool| Tool::Function {
603            function: copilot::copilot_chat::Function {
604                name: tool.name.clone(),
605                description: tool.description.clone(),
606                parameters: tool.input_schema.clone(),
607            },
608        })
609        .collect::<Vec<_>>();
610
611    Ok(CopilotChatRequest {
612        intent: true,
613        n: 1,
614        stream: model.uses_streaming(),
615        temperature: 0.1,
616        model: model.id().to_string(),
617        messages,
618        tools,
619        tool_choice: request.tool_choice.map(|choice| match choice {
620            LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
621            LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
622            LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
623        }),
624    })
625}
626
627struct ConfigurationView {
628    copilot_status: Option<copilot::Status>,
629    state: Entity<State>,
630    _subscription: Option<Subscription>,
631}
632
633impl ConfigurationView {
634    pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
635        let copilot = Copilot::global(cx);
636
637        Self {
638            copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
639            state,
640            _subscription: copilot.as_ref().map(|copilot| {
641                cx.observe(copilot, |this, model, cx| {
642                    this.copilot_status = Some(model.read(cx).status());
643                    cx.notify();
644                })
645            }),
646        }
647    }
648}
649
650impl Render for ConfigurationView {
651    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
652        if self.state.read(cx).is_authenticated(cx) {
653            h_flex()
654                .mt_1()
655                .p_1()
656                .justify_between()
657                .rounded_md()
658                .border_1()
659                .border_color(cx.theme().colors().border)
660                .bg(cx.theme().colors().background)
661                .child(
662                    h_flex()
663                        .gap_1()
664                        .child(Icon::new(IconName::Check).color(Color::Success))
665                        .child(Label::new("Authorized")),
666                )
667                .child(
668                    Button::new("sign_out", "Sign Out")
669                        .label_size(LabelSize::Small)
670                        .on_click(|_, window, cx| {
671                            window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
672                        }),
673                )
674        } else {
675            let loading_icon = Icon::new(IconName::ArrowCircle).with_animation(
676                "arrow-circle",
677                Animation::new(Duration::from_secs(4)).repeat(),
678                |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
679            );
680
681            const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
682
683            match &self.copilot_status {
684                Some(status) => match status {
685                    Status::Starting { task: _ } => h_flex()
686                        .gap_2()
687                        .child(loading_icon)
688                        .child(Label::new("Starting Copilot…")),
689                    Status::SigningIn { prompt: _ }
690                    | Status::SignedOut {
691                        awaiting_signing_in: true,
692                    } => h_flex()
693                        .gap_2()
694                        .child(loading_icon)
695                        .child(Label::new("Signing into Copilot…")),
696                    Status::Error(_) => {
697                        const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
698                        v_flex()
699                            .gap_6()
700                            .child(Label::new(LABEL))
701                            .child(svg().size_8().path(IconName::CopilotError.path()))
702                    }
703                    _ => {
704                        const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
705
706                        v_flex().gap_2().child(Label::new(LABEL)).child(
707                            Button::new("sign_in", "Sign in to use GitHub Copilot")
708                                .icon_color(Color::Muted)
709                                .icon(IconName::Github)
710                                .icon_position(IconPosition::Start)
711                                .icon_size(IconSize::Medium)
712                                .full_width()
713                                .on_click(|_, window, cx| copilot::initiate_sign_in(window, cx)),
714                        )
715                    }
716                },
717                None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),
718            }
719        }
720    }
721}