1use std::pin::Pin;
2use std::str::FromStr as _;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use collections::HashMap;
7use copilot::copilot_chat::{
8 ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
9 Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
10 ToolCall,
11};
12use copilot::{Copilot, Status};
13use futures::future::BoxFuture;
14use futures::stream::BoxStream;
15use futures::{FutureExt, Stream, StreamExt};
16use gpui::{
17 Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
18 Transformation, percentage, svg,
19};
20use language::language_settings::all_language_settings;
21use language_model::{
22 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
23 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
24 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
25 LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
26 LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
27 StopReason, TokenUsage,
28};
29use settings::SettingsStore;
30use std::time::Duration;
31use ui::prelude::*;
32use util::debug_panic;
33
34use super::anthropic::count_anthropic_tokens;
35use super::google::count_google_tokens;
36use super::open_ai::count_open_ai_tokens;
37
/// Stable identifier reported by this provider (`id`) and by every model it
/// creates (`provider_id`).
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
/// Human-readable provider name reported by `name` / `provider_name`.
const PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("GitHub Copilot Chat");
41
/// Language model provider backed by the global `CopilotChat` entity.
///
/// Models are listed from `CopilotChat` on demand; the only data held here is
/// the observable [`State`] entity that re-notifies on Copilot Chat and
/// settings changes.
pub struct CopilotChatLanguageModelProvider {
    state: Entity<State>,
}
45
/// Observable state of the Copilot Chat provider.
///
/// It stores no data of its own: it only keeps subscriptions alive so that the
/// entity notifies its observers when `CopilotChat` or the global settings
/// change.
pub struct State {
    // Observation of the global `CopilotChat` entity; `None` if that entity
    // did not exist when the state was created.
    _copilot_chat_subscription: Option<Subscription>,
    // Observation of `SettingsStore`; on change it forwards the Copilot
    // enterprise URI to `CopilotChat` before notifying.
    _settings_subscription: Subscription,
}
50
51impl State {
52 fn is_authenticated(&self, cx: &App) -> bool {
53 CopilotChat::global(cx)
54 .map(|m| m.read(cx).is_authenticated())
55 .unwrap_or(false)
56 }
57}
58
59impl CopilotChatLanguageModelProvider {
60 pub fn new(cx: &mut App) -> Self {
61 let state = cx.new(|cx| {
62 let copilot_chat_subscription = CopilotChat::global(cx)
63 .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
64 State {
65 _copilot_chat_subscription: copilot_chat_subscription,
66 _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
67 if let Some(copilot_chat) = CopilotChat::global(cx) {
68 let language_settings = all_language_settings(None, cx);
69 let configuration = copilot::copilot_chat::CopilotChatConfiguration {
70 enterprise_uri: language_settings
71 .edit_predictions
72 .copilot
73 .enterprise_uri
74 .clone(),
75 };
76 copilot_chat.update(cx, |chat, cx| {
77 chat.set_configuration(configuration, cx);
78 });
79 }
80 cx.notify();
81 }),
82 }
83 });
84
85 Self { state }
86 }
87
88 fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
89 Arc::new(CopilotChatLanguageModel {
90 model,
91 request_limiter: RateLimiter::new(4),
92 })
93 }
94}
95
96impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
97 type ObservableEntity = State;
98
99 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
100 Some(self.state.clone())
101 }
102}
103
104impl LanguageModelProvider for CopilotChatLanguageModelProvider {
105 fn id(&self) -> LanguageModelProviderId {
106 PROVIDER_ID
107 }
108
109 fn name(&self) -> LanguageModelProviderName {
110 PROVIDER_NAME
111 }
112
113 fn icon(&self) -> IconName {
114 IconName::Copilot
115 }
116
117 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
118 let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
119 models
120 .first()
121 .map(|model| self.create_language_model(model.clone()))
122 }
123
124 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
125 // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
126 // model (e.g. 4o) and a sensible choice when considering premium requests
127 self.default_model(cx)
128 }
129
130 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
131 let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
132 return Vec::new();
133 };
134 models
135 .iter()
136 .map(|model| self.create_language_model(model.clone()))
137 .collect()
138 }
139
140 fn is_authenticated(&self, cx: &App) -> bool {
141 self.state.read(cx).is_authenticated(cx)
142 }
143
144 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
145 if self.is_authenticated(cx) {
146 return Task::ready(Ok(()));
147 };
148
149 let Some(copilot) = Copilot::global(cx) else {
150 return Task::ready( Err(anyhow!(
151 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
152 ).into()));
153 };
154
155 let err = match copilot.read(cx).status() {
156 Status::Authorized => return Task::ready(Ok(())),
157 Status::Disabled => anyhow!(
158 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
159 ),
160 Status::Error(err) => anyhow!(format!(
161 "Received the following error while signing into Copilot: {err}"
162 )),
163 Status::Starting { task: _ } => anyhow!(
164 "Copilot is still starting, please wait for Copilot to start then try again"
165 ),
166 Status::Unauthorized => anyhow!(
167 "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
168 ),
169 Status::SignedOut { .. } => {
170 anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
171 }
172 Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
173 };
174
175 Task::ready(Err(err.into()))
176 }
177
178 fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
179 let state = self.state.clone();
180 cx.new(|cx| ConfigurationView::new(state, cx)).into()
181 }
182
183 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
184 Task::ready(Err(anyhow!(
185 "Signing out of GitHub Copilot Chat is currently not supported."
186 )))
187 }
188}
189
/// A single Copilot Chat model exposed through the [`LanguageModel`] trait.
pub struct CopilotChatLanguageModel {
    // Copilot model metadata: id, display name, vendor, capabilities and token limit.
    model: CopilotChatModel,
    // Wraps every `stream_completion` request. The provider builds it with
    // `RateLimiter::new(4)`, presumably capping concurrent requests at 4 (confirm).
    request_limiter: RateLimiter,
}
194
195impl LanguageModel for CopilotChatLanguageModel {
196 fn id(&self) -> LanguageModelId {
197 LanguageModelId::from(self.model.id().to_string())
198 }
199
200 fn name(&self) -> LanguageModelName {
201 LanguageModelName::from(self.model.display_name().to_string())
202 }
203
204 fn provider_id(&self) -> LanguageModelProviderId {
205 PROVIDER_ID
206 }
207
208 fn provider_name(&self) -> LanguageModelProviderName {
209 PROVIDER_NAME
210 }
211
212 fn supports_tools(&self) -> bool {
213 self.model.supports_tools()
214 }
215
216 fn supports_images(&self) -> bool {
217 self.model.supports_vision()
218 }
219
220 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
221 match self.model.vendor() {
222 ModelVendor::OpenAI | ModelVendor::Anthropic => {
223 LanguageModelToolSchemaFormat::JsonSchema
224 }
225 ModelVendor::Google => LanguageModelToolSchemaFormat::JsonSchemaSubset,
226 }
227 }
228
229 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
230 match choice {
231 LanguageModelToolChoice::Auto
232 | LanguageModelToolChoice::Any
233 | LanguageModelToolChoice::None => self.supports_tools(),
234 }
235 }
236
237 fn telemetry_id(&self) -> String {
238 format!("copilot_chat/{}", self.model.id())
239 }
240
241 fn max_token_count(&self) -> u64 {
242 self.model.max_token_count()
243 }
244
245 fn count_tokens(
246 &self,
247 request: LanguageModelRequest,
248 cx: &App,
249 ) -> BoxFuture<'static, Result<u64>> {
250 match self.model.vendor() {
251 ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
252 ModelVendor::Google => count_google_tokens(request, cx),
253 ModelVendor::OpenAI => {
254 let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
255 count_open_ai_tokens(request, model, cx)
256 }
257 }
258 }
259
260 fn stream_completion(
261 &self,
262 request: LanguageModelRequest,
263 cx: &AsyncApp,
264 ) -> BoxFuture<
265 'static,
266 Result<
267 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
268 LanguageModelCompletionError,
269 >,
270 > {
271 let copilot_request = match into_copilot_chat(&self.model, request) {
272 Ok(request) => request,
273 Err(err) => return futures::future::ready(Err(err.into())).boxed(),
274 };
275 let is_streaming = copilot_request.stream;
276
277 let request_limiter = self.request_limiter.clone();
278 let future = cx.spawn(async move |cx| {
279 let request = CopilotChat::stream_completion(copilot_request, cx.clone());
280 request_limiter
281 .stream(async move {
282 let response = request.await?;
283 Ok(map_to_language_model_completion_events(
284 response,
285 is_streaming,
286 ))
287 })
288 .await
289 });
290 async move { Ok(future.await?.boxed()) }.boxed()
291 }
292}
293
294pub fn map_to_language_model_completion_events(
295 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
296 is_streaming: bool,
297) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
298 #[derive(Default)]
299 struct RawToolCall {
300 id: String,
301 name: String,
302 arguments: String,
303 }
304
305 struct State {
306 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
307 tool_calls_by_index: HashMap<usize, RawToolCall>,
308 }
309
310 futures::stream::unfold(
311 State {
312 events,
313 tool_calls_by_index: HashMap::default(),
314 },
315 move |mut state| async move {
316 if let Some(event) = state.events.next().await {
317 match event {
318 Ok(event) => {
319 let Some(choice) = event.choices.first() else {
320 return Some((
321 vec![Err(anyhow!("Response contained no choices").into())],
322 state,
323 ));
324 };
325
326 let delta = if is_streaming {
327 choice.delta.as_ref()
328 } else {
329 choice.message.as_ref()
330 };
331
332 let Some(delta) = delta else {
333 return Some((
334 vec![Err(anyhow!("Response contained no delta").into())],
335 state,
336 ));
337 };
338
339 let mut events = Vec::new();
340 if let Some(content) = delta.content.clone() {
341 events.push(Ok(LanguageModelCompletionEvent::Text(content)));
342 }
343
344 for tool_call in &delta.tool_calls {
345 let entry = state
346 .tool_calls_by_index
347 .entry(tool_call.index)
348 .or_default();
349
350 if let Some(tool_id) = tool_call.id.clone() {
351 entry.id = tool_id;
352 }
353
354 if let Some(function) = tool_call.function.as_ref() {
355 if let Some(name) = function.name.clone() {
356 entry.name = name;
357 }
358
359 if let Some(arguments) = function.arguments.clone() {
360 entry.arguments.push_str(&arguments);
361 }
362 }
363 }
364
365 if let Some(usage) = event.usage {
366 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
367 TokenUsage {
368 input_tokens: usage.prompt_tokens,
369 output_tokens: usage.completion_tokens,
370 cache_creation_input_tokens: 0,
371 cache_read_input_tokens: 0,
372 },
373 )));
374 }
375
376 match choice.finish_reason.as_deref() {
377 Some("stop") => {
378 events.push(Ok(LanguageModelCompletionEvent::Stop(
379 StopReason::EndTurn,
380 )));
381 }
382 Some("tool_calls") => {
383 events.extend(state.tool_calls_by_index.drain().map(
384 |(_, tool_call)| {
385 // The model can output an empty string
386 // to indicate the absence of arguments.
387 // When that happens, create an empty
388 // object instead.
389 let arguments = if tool_call.arguments.is_empty() {
390 Ok(serde_json::Value::Object(Default::default()))
391 } else {
392 serde_json::Value::from_str(&tool_call.arguments)
393 };
394 match arguments {
395 Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
396 LanguageModelToolUse {
397 id: tool_call.id.clone().into(),
398 name: tool_call.name.as_str().into(),
399 is_input_complete: true,
400 input,
401 raw_input: tool_call.arguments.clone(),
402 },
403 )),
404 Err(error) => Ok(
405 LanguageModelCompletionEvent::ToolUseJsonParseError {
406 id: tool_call.id.into(),
407 tool_name: tool_call.name.as_str().into(),
408 raw_input: tool_call.arguments.into(),
409 json_parse_error: error.to_string(),
410 },
411 ),
412 }
413 },
414 ));
415
416 events.push(Ok(LanguageModelCompletionEvent::Stop(
417 StopReason::ToolUse,
418 )));
419 }
420 Some(stop_reason) => {
421 log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
422 events.push(Ok(LanguageModelCompletionEvent::Stop(
423 StopReason::EndTurn,
424 )));
425 }
426 None => {}
427 }
428
429 return Some((events, state));
430 }
431 Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
432 }
433 }
434
435 None
436 },
437 )
438 .flat_map(futures::stream::iter)
439}
440
441fn into_copilot_chat(
442 model: &copilot::copilot_chat::Model,
443 request: LanguageModelRequest,
444) -> Result<CopilotChatRequest> {
445 let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
446 for message in request.messages {
447 if let Some(last_message) = request_messages.last_mut() {
448 if last_message.role == message.role {
449 last_message.content.extend(message.content);
450 } else {
451 request_messages.push(message);
452 }
453 } else {
454 request_messages.push(message);
455 }
456 }
457
458 let mut tool_called = false;
459 let mut messages: Vec<ChatMessage> = Vec::new();
460 for message in request_messages {
461 match message.role {
462 Role::User => {
463 for content in &message.content {
464 if let MessageContent::ToolResult(tool_result) = content {
465 let content = match &tool_result.content {
466 LanguageModelToolResultContent::Text(text) => text.to_string().into(),
467 LanguageModelToolResultContent::Image(image) => {
468 if model.supports_vision() {
469 ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
470 image_url: ImageUrl {
471 url: image.to_base64_url(),
472 },
473 }])
474 } else {
475 debug_panic!(
476 "This should be caught at {} level",
477 tool_result.tool_name
478 );
479 "[Tool responded with an image, but this model does not support vision]".to_string().into()
480 }
481 }
482 };
483
484 messages.push(ChatMessage::Tool {
485 tool_call_id: tool_result.tool_use_id.to_string(),
486 content,
487 });
488 }
489 }
490
491 let mut content_parts = Vec::new();
492 for content in &message.content {
493 match content {
494 MessageContent::Text(text) | MessageContent::Thinking { text, .. }
495 if !text.is_empty() =>
496 {
497 if let Some(ChatMessagePart::Text { text: text_content }) =
498 content_parts.last_mut()
499 {
500 text_content.push_str(text);
501 } else {
502 content_parts.push(ChatMessagePart::Text {
503 text: text.to_string(),
504 });
505 }
506 }
507 MessageContent::Image(image) if model.supports_vision() => {
508 content_parts.push(ChatMessagePart::Image {
509 image_url: ImageUrl {
510 url: image.to_base64_url(),
511 },
512 });
513 }
514 _ => {}
515 }
516 }
517
518 if !content_parts.is_empty() {
519 messages.push(ChatMessage::User {
520 content: content_parts.into(),
521 });
522 }
523 }
524 Role::Assistant => {
525 let mut tool_calls = Vec::new();
526 for content in &message.content {
527 if let MessageContent::ToolUse(tool_use) = content {
528 tool_called = true;
529 tool_calls.push(ToolCall {
530 id: tool_use.id.to_string(),
531 content: copilot::copilot_chat::ToolCallContent::Function {
532 function: copilot::copilot_chat::FunctionContent {
533 name: tool_use.name.to_string(),
534 arguments: serde_json::to_string(&tool_use.input)?,
535 },
536 },
537 });
538 }
539 }
540
541 let text_content = {
542 let mut buffer = String::new();
543 for string in message.content.iter().filter_map(|content| match content {
544 MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
545 Some(text.as_str())
546 }
547 MessageContent::ToolUse(_)
548 | MessageContent::RedactedThinking(_)
549 | MessageContent::ToolResult(_)
550 | MessageContent::Image(_) => None,
551 }) {
552 buffer.push_str(string);
553 }
554
555 buffer
556 };
557
558 messages.push(ChatMessage::Assistant {
559 content: if text_content.is_empty() {
560 ChatMessageContent::empty()
561 } else {
562 text_content.into()
563 },
564 tool_calls,
565 });
566 }
567 Role::System => messages.push(ChatMessage::System {
568 content: message.string_contents(),
569 }),
570 }
571 }
572
573 let mut tools = request
574 .tools
575 .iter()
576 .map(|tool| Tool::Function {
577 function: copilot::copilot_chat::Function {
578 name: tool.name.clone(),
579 description: tool.description.clone(),
580 parameters: tool.input_schema.clone(),
581 },
582 })
583 .collect::<Vec<_>>();
584
585 // The API will return a Bad Request (with no error message) when tools
586 // were used previously in the conversation but no tools are provided as
587 // part of this request. Inserting a dummy tool seems to circumvent this
588 // error.
589 if tool_called && tools.is_empty() {
590 tools.push(Tool::Function {
591 function: copilot::copilot_chat::Function {
592 name: "noop".to_string(),
593 description: "No operation".to_string(),
594 parameters: serde_json::json!({
595 "type": "object"
596 }),
597 },
598 });
599 }
600
601 Ok(CopilotChatRequest {
602 intent: true,
603 n: 1,
604 stream: model.uses_streaming(),
605 temperature: 0.1,
606 model: model.id().to_string(),
607 messages,
608 tools,
609 tool_choice: request.tool_choice.map(|choice| match choice {
610 LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
611 LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
612 LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
613 }),
614 })
615}
616
/// Configuration UI for the Copilot Chat provider, created by the provider's
/// `configuration_view`.
struct ConfigurationView {
    // Last observed Copilot status; `None` if no global `Copilot` existed when
    // the view was created.
    copilot_status: Option<copilot::Status>,
    // Provider state, used to check whether Copilot Chat is authenticated.
    state: Entity<State>,
    // Keeps `copilot_status` in sync with the `Copilot` entity, when one exists.
    _subscription: Option<Subscription>,
}
622
623impl ConfigurationView {
624 pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
625 let copilot = Copilot::global(cx);
626
627 Self {
628 copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
629 state,
630 _subscription: copilot.as_ref().map(|copilot| {
631 cx.observe(copilot, |this, model, cx| {
632 this.copilot_status = Some(model.read(cx).status());
633 cx.notify();
634 })
635 }),
636 }
637 }
638}
639
640impl Render for ConfigurationView {
641 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
642 if self.state.read(cx).is_authenticated(cx) {
643 h_flex()
644 .mt_1()
645 .p_1()
646 .justify_between()
647 .rounded_md()
648 .border_1()
649 .border_color(cx.theme().colors().border)
650 .bg(cx.theme().colors().background)
651 .child(
652 h_flex()
653 .gap_1()
654 .child(Icon::new(IconName::Check).color(Color::Success))
655 .child(Label::new("Authorized")),
656 )
657 .child(
658 Button::new("sign_out", "Sign Out")
659 .label_size(LabelSize::Small)
660 .on_click(|_, window, cx| {
661 window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
662 }),
663 )
664 } else {
665 let loading_icon = Icon::new(IconName::ArrowCircle).with_animation(
666 "arrow-circle",
667 Animation::new(Duration::from_secs(4)).repeat(),
668 |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
669 );
670
671 const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
672
673 match &self.copilot_status {
674 Some(status) => match status {
675 Status::Starting { task: _ } => h_flex()
676 .gap_2()
677 .child(loading_icon)
678 .child(Label::new("Starting Copilot…")),
679 Status::SigningIn { prompt: _ }
680 | Status::SignedOut {
681 awaiting_signing_in: true,
682 } => h_flex()
683 .gap_2()
684 .child(loading_icon)
685 .child(Label::new("Signing into Copilot…")),
686 Status::Error(_) => {
687 const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
688 v_flex()
689 .gap_6()
690 .child(Label::new(LABEL))
691 .child(svg().size_8().path(IconName::CopilotError.path()))
692 }
693 _ => {
694 const LABEL: &str = "To use Zed's assistant with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
695 v_flex().gap_2().child(Label::new(LABEL)).child(
696 Button::new("sign_in", "Sign in to use GitHub Copilot")
697 .icon_color(Color::Muted)
698 .icon(IconName::Github)
699 .icon_position(IconPosition::Start)
700 .icon_size(IconSize::Medium)
701 .full_width()
702 .on_click(|_, window, cx| copilot::initiate_sign_in(window, cx)),
703 )
704 }
705 },
706 None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),
707 }
708 }
709 }
710}