1use std::pin::Pin;
2use std::str::FromStr as _;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use cloud_llm_client::CompletionIntent;
7use collections::HashMap;
8use copilot::copilot_chat::{
9 ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
10 Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
11 ToolCall,
12};
13use copilot::{Copilot, Status};
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16use futures::{FutureExt, Stream, StreamExt};
17use gpui::{
18 Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
19 Transformation, percentage, svg,
20};
21use language::language_settings::all_language_settings;
22use language_model::{
23 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
24 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
25 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
26 LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
27 LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
28 StopReason, TokenUsage,
29};
30use settings::SettingsStore;
31use std::time::Duration;
32use ui::prelude::*;
33use util::debug_panic;
34
35use super::anthropic::count_anthropic_tokens;
36use super::google::count_google_tokens;
37use super::open_ai::count_open_ai_tokens;
38
39const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
40const PROVIDER_NAME: LanguageModelProviderName =
41 LanguageModelProviderName::new("GitHub Copilot Chat");
42
43pub struct CopilotChatLanguageModelProvider {
44 state: Entity<State>,
45}
46
47pub struct State {
48 _copilot_chat_subscription: Option<Subscription>,
49 _settings_subscription: Subscription,
50}
51
52impl State {
53 fn is_authenticated(&self, cx: &App) -> bool {
54 CopilotChat::global(cx)
55 .map(|m| m.read(cx).is_authenticated())
56 .unwrap_or(false)
57 }
58}
59
60impl CopilotChatLanguageModelProvider {
61 pub fn new(cx: &mut App) -> Self {
62 let state = cx.new(|cx| {
63 let copilot_chat_subscription = CopilotChat::global(cx)
64 .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
65 State {
66 _copilot_chat_subscription: copilot_chat_subscription,
67 _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
68 if let Some(copilot_chat) = CopilotChat::global(cx) {
69 let language_settings = all_language_settings(None, cx);
70 let configuration = copilot::copilot_chat::CopilotChatConfiguration {
71 enterprise_uri: language_settings
72 .edit_predictions
73 .copilot
74 .enterprise_uri
75 .clone(),
76 };
77 copilot_chat.update(cx, |chat, cx| {
78 chat.set_configuration(configuration, cx);
79 });
80 }
81 cx.notify();
82 }),
83 }
84 });
85
86 Self { state }
87 }
88
89 fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
90 Arc::new(CopilotChatLanguageModel {
91 model,
92 request_limiter: RateLimiter::new(4),
93 })
94 }
95}
96
97impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
98 type ObservableEntity = State;
99
100 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
101 Some(self.state.clone())
102 }
103}
104
105impl LanguageModelProvider for CopilotChatLanguageModelProvider {
106 fn id(&self) -> LanguageModelProviderId {
107 PROVIDER_ID
108 }
109
110 fn name(&self) -> LanguageModelProviderName {
111 PROVIDER_NAME
112 }
113
114 fn icon(&self) -> IconName {
115 IconName::Copilot
116 }
117
118 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
119 let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
120 models
121 .first()
122 .map(|model| self.create_language_model(model.clone()))
123 }
124
125 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
126 // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
127 // model (e.g. 4o) and a sensible choice when considering premium requests
128 self.default_model(cx)
129 }
130
131 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
132 let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
133 return Vec::new();
134 };
135 models
136 .iter()
137 .map(|model| self.create_language_model(model.clone()))
138 .collect()
139 }
140
141 fn is_authenticated(&self, cx: &App) -> bool {
142 self.state.read(cx).is_authenticated(cx)
143 }
144
145 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
146 if self.is_authenticated(cx) {
147 return Task::ready(Ok(()));
148 };
149
150 let Some(copilot) = Copilot::global(cx) else {
151 return Task::ready( Err(anyhow!(
152 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
153 ).into()));
154 };
155
156 let err = match copilot.read(cx).status() {
157 Status::Authorized => return Task::ready(Ok(())),
158 Status::Disabled => anyhow!(
159 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
160 ),
161 Status::Error(err) => anyhow!(format!(
162 "Received the following error while signing into Copilot: {err}"
163 )),
164 Status::Starting { task: _ } => anyhow!(
165 "Copilot is still starting, please wait for Copilot to start then try again"
166 ),
167 Status::Unauthorized => anyhow!(
168 "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
169 ),
170 Status::SignedOut { .. } => {
171 anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
172 }
173 Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
174 };
175
176 Task::ready(Err(err.into()))
177 }
178
179 fn configuration_view(
180 &self,
181 _target_agent: language_model::ConfigurationViewTargetAgent,
182 _: &mut Window,
183 cx: &mut App,
184 ) -> AnyView {
185 let state = self.state.clone();
186 cx.new(|cx| ConfigurationView::new(state, cx)).into()
187 }
188
189 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
190 Task::ready(Err(anyhow!(
191 "Signing out of GitHub Copilot Chat is currently not supported."
192 )))
193 }
194}
195
196pub struct CopilotChatLanguageModel {
197 model: CopilotChatModel,
198 request_limiter: RateLimiter,
199}
200
201impl LanguageModel for CopilotChatLanguageModel {
202 fn id(&self) -> LanguageModelId {
203 LanguageModelId::from(self.model.id().to_string())
204 }
205
206 fn name(&self) -> LanguageModelName {
207 LanguageModelName::from(self.model.display_name().to_string())
208 }
209
210 fn provider_id(&self) -> LanguageModelProviderId {
211 PROVIDER_ID
212 }
213
214 fn provider_name(&self) -> LanguageModelProviderName {
215 PROVIDER_NAME
216 }
217
218 fn supports_tools(&self) -> bool {
219 self.model.supports_tools()
220 }
221
222 fn supports_images(&self) -> bool {
223 self.model.supports_vision()
224 }
225
226 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
227 match self.model.vendor() {
228 ModelVendor::OpenAI | ModelVendor::Anthropic => {
229 LanguageModelToolSchemaFormat::JsonSchema
230 }
231 ModelVendor::Google => LanguageModelToolSchemaFormat::JsonSchemaSubset,
232 }
233 }
234
235 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
236 match choice {
237 LanguageModelToolChoice::Auto
238 | LanguageModelToolChoice::Any
239 | LanguageModelToolChoice::None => self.supports_tools(),
240 }
241 }
242
243 fn telemetry_id(&self) -> String {
244 format!("copilot_chat/{}", self.model.id())
245 }
246
247 fn max_token_count(&self) -> u64 {
248 self.model.max_token_count()
249 }
250
251 fn count_tokens(
252 &self,
253 request: LanguageModelRequest,
254 cx: &App,
255 ) -> BoxFuture<'static, Result<u64>> {
256 match self.model.vendor() {
257 ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
258 ModelVendor::Google => count_google_tokens(request, cx),
259 ModelVendor::OpenAI => {
260 let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
261 count_open_ai_tokens(request, model, cx)
262 }
263 }
264 }
265
266 fn stream_completion(
267 &self,
268 request: LanguageModelRequest,
269 cx: &AsyncApp,
270 ) -> BoxFuture<
271 'static,
272 Result<
273 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
274 LanguageModelCompletionError,
275 >,
276 > {
277 let is_user_initiated = request.intent.is_none_or(|intent| match intent {
278 CompletionIntent::UserPrompt
279 | CompletionIntent::ThreadContextSummarization
280 | CompletionIntent::InlineAssist
281 | CompletionIntent::TerminalInlineAssist
282 | CompletionIntent::GenerateGitCommitMessage => true,
283
284 CompletionIntent::ToolResults
285 | CompletionIntent::ThreadSummarization
286 | CompletionIntent::CreateFile
287 | CompletionIntent::EditFile => false,
288 });
289
290 let copilot_request = match into_copilot_chat(&self.model, request) {
291 Ok(request) => request,
292 Err(err) => return futures::future::ready(Err(err.into())).boxed(),
293 };
294 let is_streaming = copilot_request.stream;
295
296 let request_limiter = self.request_limiter.clone();
297 let future = cx.spawn(async move |cx| {
298 let request =
299 CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone());
300 request_limiter
301 .stream(async move {
302 let response = request.await?;
303 Ok(map_to_language_model_completion_events(
304 response,
305 is_streaming,
306 ))
307 })
308 .await
309 });
310 async move { Ok(future.await?.boxed()) }.boxed()
311 }
312}
313
314pub fn map_to_language_model_completion_events(
315 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
316 is_streaming: bool,
317) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
318 #[derive(Default)]
319 struct RawToolCall {
320 id: String,
321 name: String,
322 arguments: String,
323 }
324
325 struct State {
326 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
327 tool_calls_by_index: HashMap<usize, RawToolCall>,
328 }
329
330 futures::stream::unfold(
331 State {
332 events,
333 tool_calls_by_index: HashMap::default(),
334 },
335 move |mut state| async move {
336 if let Some(event) = state.events.next().await {
337 match event {
338 Ok(event) => {
339 let Some(choice) = event.choices.first() else {
340 return Some((
341 vec![Err(anyhow!("Response contained no choices").into())],
342 state,
343 ));
344 };
345
346 let delta = if is_streaming {
347 choice.delta.as_ref()
348 } else {
349 choice.message.as_ref()
350 };
351
352 let Some(delta) = delta else {
353 return Some((
354 vec![Err(anyhow!("Response contained no delta").into())],
355 state,
356 ));
357 };
358
359 let mut events = Vec::new();
360 if let Some(content) = delta.content.clone() {
361 events.push(Ok(LanguageModelCompletionEvent::Text(content)));
362 }
363
364 for tool_call in &delta.tool_calls {
365 let entry = state
366 .tool_calls_by_index
367 .entry(tool_call.index)
368 .or_default();
369
370 if let Some(tool_id) = tool_call.id.clone() {
371 entry.id = tool_id;
372 }
373
374 if let Some(function) = tool_call.function.as_ref() {
375 if let Some(name) = function.name.clone() {
376 entry.name = name;
377 }
378
379 if let Some(arguments) = function.arguments.clone() {
380 entry.arguments.push_str(&arguments);
381 }
382 }
383 }
384
385 if let Some(usage) = event.usage {
386 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
387 TokenUsage {
388 input_tokens: usage.prompt_tokens,
389 output_tokens: usage.completion_tokens,
390 cache_creation_input_tokens: 0,
391 cache_read_input_tokens: 0,
392 },
393 )));
394 }
395
396 match choice.finish_reason.as_deref() {
397 Some("stop") => {
398 events.push(Ok(LanguageModelCompletionEvent::Stop(
399 StopReason::EndTurn,
400 )));
401 }
402 Some("tool_calls") => {
403 events.extend(state.tool_calls_by_index.drain().map(
404 |(_, tool_call)| {
405 // The model can output an empty string
406 // to indicate the absence of arguments.
407 // When that happens, create an empty
408 // object instead.
409 let arguments = if tool_call.arguments.is_empty() {
410 Ok(serde_json::Value::Object(Default::default()))
411 } else {
412 serde_json::Value::from_str(&tool_call.arguments)
413 };
414 match arguments {
415 Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
416 LanguageModelToolUse {
417 id: tool_call.id.clone().into(),
418 name: tool_call.name.as_str().into(),
419 is_input_complete: true,
420 input,
421 raw_input: tool_call.arguments.clone(),
422 },
423 )),
424 Err(error) => Ok(
425 LanguageModelCompletionEvent::ToolUseJsonParseError {
426 id: tool_call.id.into(),
427 tool_name: tool_call.name.as_str().into(),
428 raw_input: tool_call.arguments.into(),
429 json_parse_error: error.to_string(),
430 },
431 ),
432 }
433 },
434 ));
435
436 events.push(Ok(LanguageModelCompletionEvent::Stop(
437 StopReason::ToolUse,
438 )));
439 }
440 Some(stop_reason) => {
441 log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
442 events.push(Ok(LanguageModelCompletionEvent::Stop(
443 StopReason::EndTurn,
444 )));
445 }
446 None => {}
447 }
448
449 return Some((events, state));
450 }
451 Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
452 }
453 }
454
455 None
456 },
457 )
458 .flat_map(futures::stream::iter)
459}
460
461fn into_copilot_chat(
462 model: &copilot::copilot_chat::Model,
463 request: LanguageModelRequest,
464) -> Result<CopilotChatRequest> {
465 let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
466 for message in request.messages {
467 if let Some(last_message) = request_messages.last_mut() {
468 if last_message.role == message.role {
469 last_message.content.extend(message.content);
470 } else {
471 request_messages.push(message);
472 }
473 } else {
474 request_messages.push(message);
475 }
476 }
477
478 let mut tool_called = false;
479 let mut messages: Vec<ChatMessage> = Vec::new();
480 for message in request_messages {
481 match message.role {
482 Role::User => {
483 for content in &message.content {
484 if let MessageContent::ToolResult(tool_result) = content {
485 let content = match &tool_result.content {
486 LanguageModelToolResultContent::Text(text) => text.to_string().into(),
487 LanguageModelToolResultContent::Image(image) => {
488 if model.supports_vision() {
489 ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
490 image_url: ImageUrl {
491 url: image.to_base64_url(),
492 },
493 }])
494 } else {
495 debug_panic!(
496 "This should be caught at {} level",
497 tool_result.tool_name
498 );
499 "[Tool responded with an image, but this model does not support vision]".to_string().into()
500 }
501 }
502 };
503
504 messages.push(ChatMessage::Tool {
505 tool_call_id: tool_result.tool_use_id.to_string(),
506 content,
507 });
508 }
509 }
510
511 let mut content_parts = Vec::new();
512 for content in &message.content {
513 match content {
514 MessageContent::Text(text) | MessageContent::Thinking { text, .. }
515 if !text.is_empty() =>
516 {
517 if let Some(ChatMessagePart::Text { text: text_content }) =
518 content_parts.last_mut()
519 {
520 text_content.push_str(text);
521 } else {
522 content_parts.push(ChatMessagePart::Text {
523 text: text.to_string(),
524 });
525 }
526 }
527 MessageContent::Image(image) if model.supports_vision() => {
528 content_parts.push(ChatMessagePart::Image {
529 image_url: ImageUrl {
530 url: image.to_base64_url(),
531 },
532 });
533 }
534 _ => {}
535 }
536 }
537
538 if !content_parts.is_empty() {
539 messages.push(ChatMessage::User {
540 content: content_parts.into(),
541 });
542 }
543 }
544 Role::Assistant => {
545 let mut tool_calls = Vec::new();
546 for content in &message.content {
547 if let MessageContent::ToolUse(tool_use) = content {
548 tool_called = true;
549 tool_calls.push(ToolCall {
550 id: tool_use.id.to_string(),
551 content: copilot::copilot_chat::ToolCallContent::Function {
552 function: copilot::copilot_chat::FunctionContent {
553 name: tool_use.name.to_string(),
554 arguments: serde_json::to_string(&tool_use.input)?,
555 },
556 },
557 });
558 }
559 }
560
561 let text_content = {
562 let mut buffer = String::new();
563 for string in message.content.iter().filter_map(|content| match content {
564 MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
565 Some(text.as_str())
566 }
567 MessageContent::ToolUse(_)
568 | MessageContent::RedactedThinking(_)
569 | MessageContent::ToolResult(_)
570 | MessageContent::Image(_) => None,
571 }) {
572 buffer.push_str(string);
573 }
574
575 buffer
576 };
577
578 messages.push(ChatMessage::Assistant {
579 content: if text_content.is_empty() {
580 ChatMessageContent::empty()
581 } else {
582 text_content.into()
583 },
584 tool_calls,
585 });
586 }
587 Role::System => messages.push(ChatMessage::System {
588 content: message.string_contents(),
589 }),
590 }
591 }
592
593 let mut tools = request
594 .tools
595 .iter()
596 .map(|tool| Tool::Function {
597 function: copilot::copilot_chat::Function {
598 name: tool.name.clone(),
599 description: tool.description.clone(),
600 parameters: tool.input_schema.clone(),
601 },
602 })
603 .collect::<Vec<_>>();
604
605 // The API will return a Bad Request (with no error message) when tools
606 // were used previously in the conversation but no tools are provided as
607 // part of this request. Inserting a dummy tool seems to circumvent this
608 // error.
609 if tool_called && tools.is_empty() {
610 tools.push(Tool::Function {
611 function: copilot::copilot_chat::Function {
612 name: "noop".to_string(),
613 description: "No operation".to_string(),
614 parameters: serde_json::json!({
615 "type": "object"
616 }),
617 },
618 });
619 }
620
621 Ok(CopilotChatRequest {
622 intent: true,
623 n: 1,
624 stream: model.uses_streaming(),
625 temperature: 0.1,
626 model: model.id().to_string(),
627 messages,
628 tools,
629 tool_choice: request.tool_choice.map(|choice| match choice {
630 LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
631 LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
632 LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
633 }),
634 })
635}
636
637struct ConfigurationView {
638 copilot_status: Option<copilot::Status>,
639 state: Entity<State>,
640 _subscription: Option<Subscription>,
641}
642
643impl ConfigurationView {
644 pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
645 let copilot = Copilot::global(cx);
646
647 Self {
648 copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
649 state,
650 _subscription: copilot.as_ref().map(|copilot| {
651 cx.observe(copilot, |this, model, cx| {
652 this.copilot_status = Some(model.read(cx).status());
653 cx.notify();
654 })
655 }),
656 }
657 }
658}
659
660impl Render for ConfigurationView {
661 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
662 if self.state.read(cx).is_authenticated(cx) {
663 h_flex()
664 .mt_1()
665 .p_1()
666 .justify_between()
667 .rounded_md()
668 .border_1()
669 .border_color(cx.theme().colors().border)
670 .bg(cx.theme().colors().background)
671 .child(
672 h_flex()
673 .gap_1()
674 .child(Icon::new(IconName::Check).color(Color::Success))
675 .child(Label::new("Authorized")),
676 )
677 .child(
678 Button::new("sign_out", "Sign Out")
679 .label_size(LabelSize::Small)
680 .on_click(|_, window, cx| {
681 window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
682 }),
683 )
684 } else {
685 let loading_icon = Icon::new(IconName::ArrowCircle).with_animation(
686 "arrow-circle",
687 Animation::new(Duration::from_secs(4)).repeat(),
688 |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
689 );
690
691 const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
692
693 match &self.copilot_status {
694 Some(status) => match status {
695 Status::Starting { task: _ } => h_flex()
696 .gap_2()
697 .child(loading_icon)
698 .child(Label::new("Starting Copilot…")),
699 Status::SigningIn { prompt: _ }
700 | Status::SignedOut {
701 awaiting_signing_in: true,
702 } => h_flex()
703 .gap_2()
704 .child(loading_icon)
705 .child(Label::new("Signing into Copilot…")),
706 Status::Error(_) => {
707 const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
708 v_flex()
709 .gap_6()
710 .child(Label::new(LABEL))
711 .child(svg().size_8().path(IconName::CopilotError.path()))
712 }
713 _ => {
714 const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
715
716 v_flex().gap_2().child(Label::new(LABEL)).child(
717 Button::new("sign_in", "Sign in to use GitHub Copilot")
718 .icon_color(Color::Muted)
719 .icon(IconName::Github)
720 .icon_position(IconPosition::Start)
721 .icon_size(IconSize::Medium)
722 .full_width()
723 .on_click(|_, window, cx| copilot::initiate_sign_in(window, cx)),
724 )
725 }
726 },
727 None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),
728 }
729 }
730 }
731}