1pub mod assistant_panel;
2pub mod assistant_settings;
3mod completion_provider;
4mod context;
5pub mod context_store;
6mod inline_assistant;
7mod model_selector;
8mod prompt_library;
9mod prompts;
10mod search;
11mod slash_command;
12mod streaming_diff;
13mod terminal_inline_assistant;
14
15pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
16use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
17use assistant_slash_command::SlashCommandRegistry;
18use client::{proto, Client};
19use command_palette_hooks::CommandPaletteFilter;
20pub use completion_provider::*;
21pub use context::*;
22pub use context_store::*;
23use fs::Fs;
24use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
25use indexed_docs::IndexedDocsRegistry;
26pub(crate) use inline_assistant::*;
27pub(crate) use model_selector::*;
28use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
29use serde::{Deserialize, Serialize};
30use settings::{Settings, SettingsStore};
31use slash_command::{
32 active_command, default_command, diagnostics_command, docs_command, fetch_command,
33 file_command, now_command, project_command, prompt_command, search_command, symbols_command,
34 tabs_command, term_command,
35};
36use std::{
37 fmt::{self, Display},
38 sync::Arc,
39};
40pub(crate) use streaming_diff::*;
41
// Declares the assistant's GPUI actions under the `assistant` namespace.
// That same namespace string is `Assistant::NAMESPACE`, which `Assistant::set_enabled`
// shows or hides in the command palette according to the `enabled` assistant setting.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
61
62#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
63pub struct MessageId(clock::Lamport);
64
65impl MessageId {
66 pub fn as_u64(self) -> u64 {
67 self.0.as_u64()
68 }
69}
70
/// The role attached to a language-model message.
///
/// Serde (de)serializes the variants as lowercase strings:
/// `"user"`, `"assistant"`, `"system"`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
78
79impl Role {
80 pub fn from_proto(role: i32) -> Role {
81 match proto::LanguageModelRole::from_i32(role) {
82 Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
83 Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
84 Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
85 Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
86 None => Role::User,
87 }
88 }
89
90 pub fn to_proto(&self) -> proto::LanguageModelRole {
91 match self {
92 Role::User => proto::LanguageModelRole::LanguageModelUser,
93 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
94 Role::System => proto::LanguageModelRole::LanguageModelSystem,
95 }
96 }
97
98 pub fn cycle(self) -> Role {
99 match self {
100 Role::User => Role::Assistant,
101 Role::Assistant => Role::System,
102 Role::System => Role::User,
103 }
104 }
105}
106
107impl Display for Role {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
109 match self {
110 Role::User => write!(f, "user"),
111 Role::Assistant => write!(f, "assistant"),
112 Role::System => write!(f, "system"),
113 }
114 }
115}
116
/// A language model, tagged by the provider that serves it.
///
/// `Cloud` models are reported to telemetry under the `zed.dev/` prefix.
/// The other variants wrap provider-specific model types for OpenAI,
/// Anthropic, and Ollama.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    Cloud(CloudModel),
    OpenAi(OpenAiModel),
    Anthropic(AnthropicModel),
    Ollama(OllamaModel),
}
124
125impl Default for LanguageModel {
126 fn default() -> Self {
127 LanguageModel::Cloud(CloudModel::default())
128 }
129}
130
131impl LanguageModel {
132 pub fn telemetry_id(&self) -> String {
133 match self {
134 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
135 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
136 LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
137 LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
138 }
139 }
140
141 pub fn display_name(&self) -> String {
142 match self {
143 LanguageModel::OpenAi(model) => model.display_name().into(),
144 LanguageModel::Anthropic(model) => model.display_name().into(),
145 LanguageModel::Cloud(model) => model.display_name().into(),
146 LanguageModel::Ollama(model) => model.display_name().into(),
147 }
148 }
149
150 pub fn max_token_count(&self) -> usize {
151 match self {
152 LanguageModel::OpenAi(model) => model.max_token_count(),
153 LanguageModel::Anthropic(model) => model.max_token_count(),
154 LanguageModel::Cloud(model) => model.max_token_count(),
155 LanguageModel::Ollama(model) => model.max_token_count(),
156 }
157 }
158
159 pub fn id(&self) -> &str {
160 match self {
161 LanguageModel::OpenAi(model) => model.id(),
162 LanguageModel::Anthropic(model) => model.id(),
163 LanguageModel::Cloud(model) => model.id(),
164 LanguageModel::Ollama(model) => model.id(),
165 }
166 }
167}
168
169#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
170pub struct LanguageModelRequestMessage {
171 pub role: Role,
172 pub content: String,
173}
174
175impl LanguageModelRequestMessage {
176 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
177 proto::LanguageModelRequestMessage {
178 role: self.role.to_proto() as i32,
179 content: self.content.clone(),
180 tool_calls: Vec::new(),
181 tool_call_id: None,
182 }
183 }
184}
185
186#[derive(Debug, Default, Serialize, Deserialize)]
187pub struct LanguageModelRequest {
188 pub model: LanguageModel,
189 pub messages: Vec<LanguageModelRequestMessage>,
190 pub stop: Vec<String>,
191 pub temperature: f32,
192}
193
194impl LanguageModelRequest {
195 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
196 proto::CompleteWithLanguageModel {
197 model: self.model.id().to_string(),
198 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
199 stop: self.stop.clone(),
200 temperature: self.temperature,
201 tool_choice: None,
202 tools: Vec::new(),
203 }
204 }
205
206 /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
207 pub fn preprocess(&mut self) {
208 match &self.model {
209 LanguageModel::OpenAi(_) => {}
210 LanguageModel::Anthropic(_) => {}
211 LanguageModel::Ollama(_) => {}
212 LanguageModel::Cloud(model) => match model {
213 CloudModel::Claude3Opus
214 | CloudModel::Claude3Sonnet
215 | CloudModel::Claude3Haiku
216 | CloudModel::Claude3_5Sonnet => {
217 preprocess_anthropic_request(self);
218 }
219 _ => {}
220 },
221 }
222 }
223}
224
/// A message received from a language model.
///
/// Both fields are optional. Presumably this is because an incremental
/// (streamed) chunk may omit either one; confirm against the producer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
230
/// Token accounting for a completion: prompt, completion, and total counts.
///
/// Deserialize-only. NOTE(review): presumably parsed from a provider's usage
/// payload; confirm which provider formats feed this.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
237
/// One choice entry carrying a partial response message (`delta`).
///
/// NOTE(review): the shape (`index`, `delta`, `finish_reason`) looks like an
/// OpenAI-style streaming `choices[]` element; confirm against the caller.
/// `finish_reason` is `None` when the payload omits it.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    pub delta: LanguageModelResponseMessage,
    pub finish_reason: Option<String>,
}
244
/// Status of a message: still pending, finished, or failed.
///
/// `Error` carries the error message text, which round-trips through
/// `proto::context_message_status::Error::message`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MessageStatus {
    Pending,
    Done,
    Error(SharedString),
}
251
252impl MessageStatus {
253 pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
254 match status.variant {
255 Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
256 Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
257 Some(proto::context_message_status::Variant::Error(error)) => {
258 MessageStatus::Error(error.message.into())
259 }
260 None => MessageStatus::Pending,
261 }
262 }
263
264 pub fn to_proto(&self) -> proto::ContextMessageStatus {
265 match self {
266 MessageStatus::Pending => proto::ContextMessageStatus {
267 variant: Some(proto::context_message_status::Variant::Pending(
268 proto::context_message_status::Pending {},
269 )),
270 },
271 MessageStatus::Done => proto::ContextMessageStatus {
272 variant: Some(proto::context_message_status::Variant::Done(
273 proto::context_message_status::Done {},
274 )),
275 },
276 MessageStatus::Error(message) => proto::ContextMessageStatus {
277 variant: Some(proto::context_message_status::Variant::Error(
278 proto::context_message_status::Error {
279 message: message.to_string(),
280 },
281 )),
282 },
283 }
284 }
285}
286
/// The state pertaining to the Assistant.
///
/// Installed as a GPUI global by `init`. Its `enabled` flag is kept in sync
/// with `AssistantSettings::enabled` through `Assistant::set_enabled`.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled. Starts `false` (the `Default` value).
    enabled: bool,
}

impl Global for Assistant {}
295
296impl Assistant {
297 const NAMESPACE: &'static str = "assistant";
298
299 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
300 if self.enabled == enabled {
301 return;
302 }
303
304 self.enabled = enabled;
305
306 if !enabled {
307 CommandPaletteFilter::update_global(cx, |filter, _cx| {
308 filter.hide_namespace(Self::NAMESPACE);
309 });
310
311 return;
312 }
313
314 CommandPaletteFilter::update_global(cx, |filter, _cx| {
315 filter.show_namespace(Self::NAMESPACE);
316 });
317 }
318}
319
320pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
321 cx.set_global(Assistant::default());
322 AssistantSettings::register(cx);
323
324 cx.spawn(|mut cx| {
325 let client = client.clone();
326 async move {
327 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
328 let semantic_index = SemanticIndex::new(
329 paths::embeddings_dir().join("semantic-index-db.0.mdb"),
330 Arc::new(embedding_provider),
331 &mut cx,
332 )
333 .await?;
334 cx.update(|cx| cx.set_global(semantic_index))
335 }
336 })
337 .detach();
338
339 context_store::init(&client);
340 prompt_library::init(cx);
341 completion_provider::init(client.clone(), cx);
342 assistant_slash_command::init(cx);
343 register_slash_commands(cx);
344 assistant_panel::init(cx);
345 inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
346 terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
347 IndexedDocsRegistry::init_global(cx);
348
349 CommandPaletteFilter::update_global(cx, |filter, _cx| {
350 filter.hide_namespace(Assistant::NAMESPACE);
351 });
352 Assistant::update_global(cx, |assistant, cx| {
353 let settings = AssistantSettings::get_global(cx);
354
355 assistant.set_enabled(settings.enabled, cx);
356 });
357 cx.observe_global::<SettingsStore>(|cx| {
358 Assistant::update_global(cx, |assistant, cx| {
359 let settings = AssistantSettings::get_global(cx);
360 assistant.set_enabled(settings.enabled, cx);
361 });
362 })
363 .detach();
364}
365
366fn register_slash_commands(cx: &mut AppContext) {
367 let slash_command_registry = SlashCommandRegistry::global(cx);
368 slash_command_registry.register_command(file_command::FileSlashCommand, true);
369 slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
370 slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
371 slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
372 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
373 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
374 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
375 slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
376 slash_command_registry.register_command(term_command::TermSlashCommand, true);
377 slash_command_registry.register_command(now_command::NowSlashCommand, true);
378 slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
379 slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
380 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
381}
382
383pub fn humanize_token_count(count: usize) -> String {
384 match count {
385 0..=999 => count.to_string(),
386 1000..=9999 => {
387 let thousands = count / 1000;
388 let hundreds = (count % 1000 + 50) / 100;
389 if hundreds == 0 {
390 format!("{}k", thousands)
391 } else if hundreds == 10 {
392 format!("{}k", thousands + 1)
393 } else {
394 format!("{}.{}k", thousands, hundreds)
395 }
396 }
397 _ => format!("{}k", (count + 500) / 1000),
398 }
399}
400
/// Test-only hook, run by `ctor` when the test binary loads.
///
/// It installs `env_logger` only when `RUST_LOG` is set. `std::env::var` also
/// returns `Err` for a non-UTF-8 value, so no logger is installed in that case either.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}