1use crate::{
2 assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
3 LanguageModelRequest,
4};
5use anyhow::{anyhow, Result};
6use client::{proto, Client};
7use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
8use gpui::{AnyView, AppContext, Task};
9use std::{future, sync::Arc};
10use strum::IntoEnumIterator;
11use ui::prelude::*;
12
13pub struct ZedDotDevCompletionProvider {
14 client: Arc<Client>,
15 model: ZedDotDevModel,
16 settings_version: usize,
17 status: client::Status,
18 _maintain_client_status: Task<()>,
19}
20
21impl ZedDotDevCompletionProvider {
22 pub fn new(
23 model: ZedDotDevModel,
24 client: Arc<Client>,
25 settings_version: usize,
26 cx: &mut AppContext,
27 ) -> Self {
28 let mut status_rx = client.status();
29 let status = *status_rx.borrow();
30 let maintain_client_status = cx.spawn(|mut cx| async move {
31 while let Some(status) = status_rx.next().await {
32 let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
33 if let CompletionProvider::ZedDotDev(provider) = provider {
34 provider.status = status;
35 } else {
36 unreachable!()
37 }
38 });
39 }
40 });
41 Self {
42 client,
43 model,
44 settings_version,
45 status,
46 _maintain_client_status: maintain_client_status,
47 }
48 }
49
50 pub fn update(&mut self, model: ZedDotDevModel, settings_version: usize) {
51 self.model = model;
52 self.settings_version = settings_version;
53 }
54
55 pub fn available_models(&self) -> impl Iterator<Item = ZedDotDevModel> {
56 let mut custom_model = if let ZedDotDevModel::Custom(custom_model) = self.model.clone() {
57 Some(custom_model)
58 } else {
59 None
60 };
61 ZedDotDevModel::iter().filter_map(move |model| {
62 if let ZedDotDevModel::Custom(_) = model {
63 Some(ZedDotDevModel::Custom(custom_model.take()?))
64 } else {
65 Some(model)
66 }
67 })
68 }
69
70 pub fn settings_version(&self) -> usize {
71 self.settings_version
72 }
73
74 pub fn model(&self) -> ZedDotDevModel {
75 self.model.clone()
76 }
77
78 pub fn is_authenticated(&self) -> bool {
79 self.status.is_connected()
80 }
81
82 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
83 let client = self.client.clone();
84 cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
85 }
86
87 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
88 cx.new_view(|_cx| AuthenticationPrompt).into()
89 }
90
91 pub fn count_tokens(
92 &self,
93 request: LanguageModelRequest,
94 cx: &AppContext,
95 ) -> BoxFuture<'static, Result<usize>> {
96 match request.model {
97 LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
98 | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
99 | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
100 | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
101 count_open_ai_tokens(request, cx.background_executor())
102 }
103 LanguageModel::ZedDotDev(
104 ZedDotDevModel::Claude3Opus
105 | ZedDotDevModel::Claude3Sonnet
106 | ZedDotDevModel::Claude3Haiku,
107 ) => {
108 // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
109 count_open_ai_tokens(request, cx.background_executor())
110 }
111 LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
112 let request = self.client.request(proto::CountTokensWithLanguageModel {
113 model,
114 messages: request
115 .messages
116 .iter()
117 .map(|message| message.to_proto())
118 .collect(),
119 });
120 async move {
121 let response = request.await?;
122 Ok(response.token_count as usize)
123 }
124 .boxed()
125 }
126 _ => future::ready(Err(anyhow!("invalid model"))).boxed(),
127 }
128 }
129
130 pub fn complete(
131 &self,
132 request: LanguageModelRequest,
133 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
134 let request = proto::CompleteWithLanguageModel {
135 model: request.model.id().to_string(),
136 messages: request
137 .messages
138 .iter()
139 .map(|message| message.to_proto())
140 .collect(),
141 stop: request.stop,
142 temperature: request.temperature,
143 tools: Vec::new(),
144 tool_choice: None,
145 };
146
147 self.client
148 .request_stream(request)
149 .map_ok(|stream| {
150 stream
151 .filter_map(|response| async move {
152 match response {
153 Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
154 Err(error) => Some(Err(error)),
155 }
156 })
157 .boxed()
158 })
159 .boxed()
160 }
161}
162
163struct AuthenticationPrompt;
164
165impl Render for AuthenticationPrompt {
166 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
167 const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
168
169 v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
170 v_flex()
171 .gap_2()
172 .child(
173 Button::new("sign_in", "Sign in")
174 .icon_color(Color::Muted)
175 .icon(IconName::Github)
176 .icon_position(IconPosition::Start)
177 .style(ButtonStyle::Filled)
178 .full_width()
179 .on_click(|_, cx| {
180 CompletionProvider::global(cx)
181 .authenticate(cx)
182 .detach_and_log_err(cx);
183 }),
184 )
185 .child(
186 div().flex().w_full().items_center().child(
187 Label::new("Sign in to enable collaboration.")
188 .color(Color::Muted)
189 .size(LabelSize::Small),
190 ),
191 ),
192 )
193 }
194}