1use std::pin::Pin;
2use std::str::FromStr as _;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use collections::HashMap;
7use copilot::copilot_chat::{
8 ChatMessage, ChatMessageContent, CopilotChat, ImageUrl, Model as CopilotChatModel, ModelVendor,
9 Request as CopilotChatRequest, ResponseEvent, Tool, ToolCall,
10};
11use copilot::{Copilot, Status};
12use futures::future::BoxFuture;
13use futures::stream::BoxStream;
14use futures::{FutureExt, Stream, StreamExt};
15use gpui::{
16 Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
17 Transformation, percentage, svg,
18};
19use language_model::{
20 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
21 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
22 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
23 LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
24 LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
25};
26use settings::SettingsStore;
27use std::time::Duration;
28use ui::prelude::*;
29
30use super::anthropic::count_anthropic_tokens;
31use super::google::count_google_tokens;
32use super::open_ai::count_open_ai_tokens;
33
/// Identifier reported by this provider (and its models) as the provider id.
const PROVIDER_ID: &str = "copilot_chat";
/// Display name reported by this provider (and its models) as the provider name.
const PROVIDER_NAME: &str = "GitHub Copilot Chat";

/// Settings for the Copilot Chat provider; it currently carries no fields.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CopilotChatSettings {}
39
40pub struct CopilotChatLanguageModelProvider {
41 state: Entity<State>,
42}
43
44pub struct State {
45 _copilot_chat_subscription: Option<Subscription>,
46 _settings_subscription: Subscription,
47}
48
49impl State {
50 fn is_authenticated(&self, cx: &App) -> bool {
51 CopilotChat::global(cx)
52 .map(|m| m.read(cx).is_authenticated())
53 .unwrap_or(false)
54 }
55}
56
57impl CopilotChatLanguageModelProvider {
58 pub fn new(cx: &mut App) -> Self {
59 let state = cx.new(|cx| {
60 let _copilot_chat_subscription = CopilotChat::global(cx)
61 .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
62 State {
63 _copilot_chat_subscription,
64 _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
65 cx.notify();
66 }),
67 }
68 });
69
70 Self { state }
71 }
72
73 fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
74 Arc::new(CopilotChatLanguageModel {
75 model,
76 request_limiter: RateLimiter::new(4),
77 })
78 }
79}
80
81impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
82 type ObservableEntity = State;
83
84 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
85 Some(self.state.clone())
86 }
87}
88
89impl LanguageModelProvider for CopilotChatLanguageModelProvider {
90 fn id(&self) -> LanguageModelProviderId {
91 LanguageModelProviderId(PROVIDER_ID.into())
92 }
93
94 fn name(&self) -> LanguageModelProviderName {
95 LanguageModelProviderName(PROVIDER_NAME.into())
96 }
97
98 fn icon(&self) -> IconName {
99 IconName::Copilot
100 }
101
102 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
103 let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
104 models
105 .first()
106 .map(|model| self.create_language_model(model.clone()))
107 }
108
109 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
110 // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
111 // model (e.g. 4o) and a sensible choice when considering premium requests
112 self.default_model(cx)
113 }
114
115 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
116 let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
117 return Vec::new();
118 };
119 models
120 .iter()
121 .map(|model| self.create_language_model(model.clone()))
122 .collect()
123 }
124
125 fn is_authenticated(&self, cx: &App) -> bool {
126 self.state.read(cx).is_authenticated(cx)
127 }
128
129 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
130 if self.is_authenticated(cx) {
131 return Task::ready(Ok(()));
132 };
133
134 let Some(copilot) = Copilot::global(cx) else {
135 return Task::ready( Err(anyhow!(
136 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
137 ).into()));
138 };
139
140 let err = match copilot.read(cx).status() {
141 Status::Authorized => return Task::ready(Ok(())),
142 Status::Disabled => anyhow!(
143 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
144 ),
145 Status::Error(err) => anyhow!(format!(
146 "Received the following error while signing into Copilot: {err}"
147 )),
148 Status::Starting { task: _ } => anyhow!(
149 "Copilot is still starting, please wait for Copilot to start then try again"
150 ),
151 Status::Unauthorized => anyhow!(
152 "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
153 ),
154 Status::SignedOut { .. } => {
155 anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
156 }
157 Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
158 };
159
160 Task::ready(Err(err.into()))
161 }
162
163 fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
164 let state = self.state.clone();
165 cx.new(|cx| ConfigurationView::new(state, cx)).into()
166 }
167
168 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
169 Task::ready(Err(anyhow!(
170 "Signing out of GitHub Copilot Chat is currently not supported."
171 )))
172 }
173}
174
175pub struct CopilotChatLanguageModel {
176 model: CopilotChatModel,
177 request_limiter: RateLimiter,
178}
179
180impl LanguageModel for CopilotChatLanguageModel {
181 fn id(&self) -> LanguageModelId {
182 LanguageModelId::from(self.model.id().to_string())
183 }
184
185 fn name(&self) -> LanguageModelName {
186 LanguageModelName::from(self.model.display_name().to_string())
187 }
188
189 fn provider_id(&self) -> LanguageModelProviderId {
190 LanguageModelProviderId(PROVIDER_ID.into())
191 }
192
193 fn provider_name(&self) -> LanguageModelProviderName {
194 LanguageModelProviderName(PROVIDER_NAME.into())
195 }
196
197 fn supports_tools(&self) -> bool {
198 self.model.supports_tools()
199 }
200
201 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
202 match self.model.vendor() {
203 ModelVendor::OpenAI | ModelVendor::Anthropic => {
204 LanguageModelToolSchemaFormat::JsonSchema
205 }
206 ModelVendor::Google => LanguageModelToolSchemaFormat::JsonSchemaSubset,
207 }
208 }
209
210 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
211 match choice {
212 LanguageModelToolChoice::Auto
213 | LanguageModelToolChoice::Any
214 | LanguageModelToolChoice::None => self.supports_tools(),
215 }
216 }
217
218 fn telemetry_id(&self) -> String {
219 format!("copilot_chat/{}", self.model.id())
220 }
221
222 fn max_token_count(&self) -> usize {
223 self.model.max_token_count()
224 }
225
226 fn count_tokens(
227 &self,
228 request: LanguageModelRequest,
229 cx: &App,
230 ) -> BoxFuture<'static, Result<usize>> {
231 match self.model.vendor() {
232 ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
233 ModelVendor::Google => count_google_tokens(request, cx),
234 ModelVendor::OpenAI => {
235 let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
236 count_open_ai_tokens(request, model, cx)
237 }
238 }
239 }
240
241 fn stream_completion(
242 &self,
243 request: LanguageModelRequest,
244 cx: &AsyncApp,
245 ) -> BoxFuture<
246 'static,
247 Result<
248 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
249 >,
250 > {
251 if let Some(message) = request.messages.last() {
252 if message.contents_empty() {
253 const EMPTY_PROMPT_MSG: &str =
254 "Empty prompts aren't allowed. Please provide a non-empty prompt.";
255 return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG))).boxed();
256 }
257
258 // Copilot Chat has a restriction that the final message must be from the user.
259 // While their API does return an error message for this, we can catch it earlier
260 // and provide a more helpful error message.
261 if !matches!(message.role, Role::User) {
262 const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
263 return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG))).boxed();
264 }
265 }
266
267 let copilot_request = match self.to_copilot_chat_request(request) {
268 Ok(request) => request,
269 Err(err) => return futures::future::ready(Err(err)).boxed(),
270 };
271 let is_streaming = copilot_request.stream;
272
273 let request_limiter = self.request_limiter.clone();
274 let future = cx.spawn(async move |cx| {
275 let request = CopilotChat::stream_completion(copilot_request, cx.clone());
276 request_limiter
277 .stream(async move {
278 let response = request.await?;
279 Ok(map_to_language_model_completion_events(
280 response,
281 is_streaming,
282 ))
283 })
284 .await
285 });
286 async move { Ok(future.await?.boxed()) }.boxed()
287 }
288}
289
290pub fn map_to_language_model_completion_events(
291 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
292 is_streaming: bool,
293) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
294 #[derive(Default)]
295 struct RawToolCall {
296 id: String,
297 name: String,
298 arguments: String,
299 }
300
301 struct State {
302 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
303 tool_calls_by_index: HashMap<usize, RawToolCall>,
304 }
305
306 futures::stream::unfold(
307 State {
308 events,
309 tool_calls_by_index: HashMap::default(),
310 },
311 move |mut state| async move {
312 if let Some(event) = state.events.next().await {
313 match event {
314 Ok(event) => {
315 let Some(choice) = event.choices.first() else {
316 return Some((
317 vec![Err(anyhow!("Response contained no choices").into())],
318 state,
319 ));
320 };
321
322 let delta = if is_streaming {
323 choice.delta.as_ref()
324 } else {
325 choice.message.as_ref()
326 };
327
328 let Some(delta) = delta else {
329 return Some((
330 vec![Err(anyhow!("Response contained no delta").into())],
331 state,
332 ));
333 };
334
335 let mut events = Vec::new();
336 if let Some(content) = delta.content.clone() {
337 events.push(Ok(LanguageModelCompletionEvent::Text(content)));
338 }
339
340 for tool_call in &delta.tool_calls {
341 let entry = state
342 .tool_calls_by_index
343 .entry(tool_call.index)
344 .or_default();
345
346 if let Some(tool_id) = tool_call.id.clone() {
347 entry.id = tool_id;
348 }
349
350 if let Some(function) = tool_call.function.as_ref() {
351 if let Some(name) = function.name.clone() {
352 entry.name = name;
353 }
354
355 if let Some(arguments) = function.arguments.clone() {
356 entry.arguments.push_str(&arguments);
357 }
358 }
359 }
360
361 match choice.finish_reason.as_deref() {
362 Some("stop") => {
363 events.push(Ok(LanguageModelCompletionEvent::Stop(
364 StopReason::EndTurn,
365 )));
366 }
367 Some("tool_calls") => {
368 events.extend(state.tool_calls_by_index.drain().map(
369 |(_, tool_call)| {
370 // The model can output an empty string
371 // to indicate the absence of arguments.
372 // When that happens, create an empty
373 // object instead.
374 let arguments = if tool_call.arguments.is_empty() {
375 Ok(serde_json::Value::Object(Default::default()))
376 } else {
377 serde_json::Value::from_str(&tool_call.arguments)
378 };
379 match arguments {
380 Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
381 LanguageModelToolUse {
382 id: tool_call.id.clone().into(),
383 name: tool_call.name.as_str().into(),
384 is_input_complete: true,
385 input,
386 raw_input: tool_call.arguments.clone(),
387 },
388 )),
389 Err(error) => {
390 Err(LanguageModelCompletionError::BadInputJson {
391 id: tool_call.id.into(),
392 tool_name: tool_call.name.as_str().into(),
393 raw_input: tool_call.arguments.into(),
394 json_parse_error: error.to_string(),
395 })
396 }
397 }
398 },
399 ));
400
401 events.push(Ok(LanguageModelCompletionEvent::Stop(
402 StopReason::ToolUse,
403 )));
404 }
405 Some(stop_reason) => {
406 log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
407 events.push(Ok(LanguageModelCompletionEvent::Stop(
408 StopReason::EndTurn,
409 )));
410 }
411 None => {}
412 }
413
414 return Some((events, state));
415 }
416 Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
417 }
418 }
419
420 None
421 },
422 )
423 .flat_map(futures::stream::iter)
424}
425
426impl CopilotChatLanguageModel {
427 pub fn to_copilot_chat_request(
428 &self,
429 request: LanguageModelRequest,
430 ) -> Result<CopilotChatRequest> {
431 let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
432 for message in request.messages {
433 if let Some(last_message) = request_messages.last_mut() {
434 if last_message.role == message.role {
435 last_message.content.extend(message.content);
436 } else {
437 request_messages.push(message);
438 }
439 } else {
440 request_messages.push(message);
441 }
442 }
443
444 let mut tool_called = false;
445 let mut messages: Vec<ChatMessage> = Vec::new();
446 for message in request_messages {
447 match message.role {
448 Role::User => {
449 for content in &message.content {
450 if let MessageContent::ToolResult(tool_result) = content {
451 messages.push(ChatMessage::Tool {
452 tool_call_id: tool_result.tool_use_id.to_string(),
453 content: tool_result.content.to_string(),
454 });
455 }
456 }
457
458 let mut content_parts = Vec::new();
459 for content in &message.content {
460 match content {
461 MessageContent::Text(text) | MessageContent::Thinking { text, .. }
462 if !text.is_empty() =>
463 {
464 if let Some(ChatMessageContent::Text { text: text_content }) =
465 content_parts.last_mut()
466 {
467 text_content.push_str(text);
468 } else {
469 content_parts.push(ChatMessageContent::Text {
470 text: text.to_string(),
471 });
472 }
473 }
474 MessageContent::Image(image) if self.model.supports_vision() => {
475 content_parts.push(ChatMessageContent::Image {
476 image_url: ImageUrl {
477 url: image.to_base64_url(),
478 },
479 });
480 }
481 _ => {}
482 }
483 }
484
485 if !content_parts.is_empty() {
486 messages.push(ChatMessage::User {
487 content: content_parts,
488 });
489 }
490 }
491 Role::Assistant => {
492 let mut tool_calls = Vec::new();
493 for content in &message.content {
494 if let MessageContent::ToolUse(tool_use) = content {
495 tool_called = true;
496 tool_calls.push(ToolCall {
497 id: tool_use.id.to_string(),
498 content: copilot::copilot_chat::ToolCallContent::Function {
499 function: copilot::copilot_chat::FunctionContent {
500 name: tool_use.name.to_string(),
501 arguments: serde_json::to_string(&tool_use.input)?,
502 },
503 },
504 });
505 }
506 }
507
508 let text_content = {
509 let mut buffer = String::new();
510 for string in message.content.iter().filter_map(|content| match content {
511 MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
512 Some(text.as_str())
513 }
514 MessageContent::ToolUse(_)
515 | MessageContent::RedactedThinking(_)
516 | MessageContent::ToolResult(_)
517 | MessageContent::Image(_) => None,
518 }) {
519 buffer.push_str(string);
520 }
521
522 buffer
523 };
524
525 messages.push(ChatMessage::Assistant {
526 content: if text_content.is_empty() {
527 None
528 } else {
529 Some(text_content)
530 },
531 tool_calls,
532 });
533 }
534 Role::System => messages.push(ChatMessage::System {
535 content: message.string_contents(),
536 }),
537 }
538 }
539
540 let mut tools = request
541 .tools
542 .iter()
543 .map(|tool| Tool::Function {
544 function: copilot::copilot_chat::Function {
545 name: tool.name.clone(),
546 description: tool.description.clone(),
547 parameters: tool.input_schema.clone(),
548 },
549 })
550 .collect::<Vec<_>>();
551
552 // The API will return a Bad Request (with no error message) when tools
553 // were used previously in the conversation but no tools are provided as
554 // part of this request. Inserting a dummy tool seems to circumvent this
555 // error.
556 if tool_called && tools.is_empty() {
557 tools.push(Tool::Function {
558 function: copilot::copilot_chat::Function {
559 name: "noop".to_string(),
560 description: "No operation".to_string(),
561 parameters: serde_json::json!({
562 "type": "object"
563 }),
564 },
565 });
566 }
567
568 Ok(CopilotChatRequest {
569 intent: true,
570 n: 1,
571 stream: self.model.uses_streaming(),
572 temperature: 0.1,
573 model: self.model.id().to_string(),
574 messages,
575 tools,
576 tool_choice: request.tool_choice.map(|choice| match choice {
577 LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
578 LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
579 LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
580 }),
581 })
582 }
583}
584
585struct ConfigurationView {
586 copilot_status: Option<copilot::Status>,
587 state: Entity<State>,
588 _subscription: Option<Subscription>,
589}
590
591impl ConfigurationView {
592 pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
593 let copilot = Copilot::global(cx);
594
595 Self {
596 copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
597 state,
598 _subscription: copilot.as_ref().map(|copilot| {
599 cx.observe(copilot, |this, model, cx| {
600 this.copilot_status = Some(model.read(cx).status());
601 cx.notify();
602 })
603 }),
604 }
605 }
606}
607
608impl Render for ConfigurationView {
609 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
610 if self.state.read(cx).is_authenticated(cx) {
611 h_flex()
612 .mt_1()
613 .p_1()
614 .justify_between()
615 .rounded_md()
616 .border_1()
617 .border_color(cx.theme().colors().border)
618 .bg(cx.theme().colors().background)
619 .child(
620 h_flex()
621 .gap_1()
622 .child(Icon::new(IconName::Check).color(Color::Success))
623 .child(Label::new("Authorized")),
624 )
625 .child(
626 Button::new("sign_out", "Sign Out")
627 .label_size(LabelSize::Small)
628 .on_click(|_, window, cx| {
629 window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
630 }),
631 )
632 } else {
633 let loading_icon = Icon::new(IconName::ArrowCircle).with_animation(
634 "arrow-circle",
635 Animation::new(Duration::from_secs(4)).repeat(),
636 |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
637 );
638
639 const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
640
641 match &self.copilot_status {
642 Some(status) => match status {
643 Status::Starting { task: _ } => h_flex()
644 .gap_2()
645 .child(loading_icon)
646 .child(Label::new("Starting Copilot…")),
647 Status::SigningIn { prompt: _ }
648 | Status::SignedOut {
649 awaiting_signing_in: true,
650 } => h_flex()
651 .gap_2()
652 .child(loading_icon)
653 .child(Label::new("Signing into Copilot…")),
654 Status::Error(_) => {
655 const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
656 v_flex()
657 .gap_6()
658 .child(Label::new(LABEL))
659 .child(svg().size_8().path(IconName::CopilotError.path()))
660 }
661 _ => {
662 const LABEL: &str = "To use Zed's assistant with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
663 v_flex().gap_2().child(Label::new(LABEL)).child(
664 Button::new("sign_in", "Sign in to use GitHub Copilot")
665 .icon_color(Color::Muted)
666 .icon(IconName::Github)
667 .icon_position(IconPosition::Start)
668 .icon_size(IconSize::Medium)
669 .full_width()
670 .on_click(|_, window, cx| copilot::initiate_sign_in(window, cx)),
671 )
672 }
673 },
674 None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),
675 }
676 }
677 }
678}