pub mod assistant_panel;
pub mod assistant_settings;
mod completion_provider;
mod context_store;
mod inline_assistant;
mod model_selector;
mod prompt_library;
mod prompts;
mod search;
mod slash_command;
mod streaming_diff;
mod terminal_inline_assistant;

pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub(crate) use completion_provider::*;
pub(crate) use context_store::*;
use fs::Fs;
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*;
pub(crate) use model_selector::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use slash_command::{
    active_command, default_command, diagnostics_command, docs_command, fetch_command,
    file_command, now_command, project_command, prompt_command, search_command, tabs_command,
    term_command,
};
use std::{
    fmt::{self, Display},
    sync::Arc,
};
pub(crate) use streaming_diff::*;

actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);

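/// Identifies a single message within an assistant conversation.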
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);

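/// The author of a message in a language model conversation.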
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl Role {
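    /// Cycles to the next role: User -> Assistant -> System -> User.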
    pub fn cycle(&mut self) {
        *self = match self {
            Role::User => Role::Assistant,
            Role::Assistant => Role::System,
            Role::System => Role::User,
        }
    }
}

impl Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Role::User => write!(f, "user"),
            Role::Assistant => write!(f, "assistant"),
            Role::System => write!(f, "system"),
        }
    }
}

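/// A language model that the assistant can send requests to, grouped by provider.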
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    Cloud(CloudModel),
    OpenAi(OpenAiModel),
    Anthropic(AnthropicModel),
    Ollama(OllamaModel),
}

impl Default for LanguageModel {
    fn default() -> Self {
        LanguageModel::Cloud(CloudModel::default())
    }
}

impl LanguageModel {
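    /// Returns a `provider/model-id` string used to identify this model in telemetry events.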
    pub fn telemetry_id(&self) -> String {
        match self {
            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
        }
    }

    pub fn display_name(&self) -> String {
        match self {
            LanguageModel::OpenAi(model) => model.display_name().into(),
            LanguageModel::Anthropic(model) => model.display_name().into(),
            LanguageModel::Cloud(model) => model.display_name().into(),
            LanguageModel::Ollama(model) => model.display_name().into(),
        }
    }

    pub fn max_token_count(&self) -> usize {
        match self {
            LanguageModel::OpenAi(model) => model.max_token_count(),
            LanguageModel::Anthropic(model) => model.max_token_count(),
            LanguageModel::Cloud(model) => model.max_token_count(),
            LanguageModel::Ollama(model) => model.max_token_count(),
        }
    }

    pub fn id(&self) -> &str {
        match self {
            LanguageModel::OpenAi(model) => model.id(),
            LanguageModel::Anthropic(model) => model.id(),
            LanguageModel::Cloud(model) => model.id(),
            LanguageModel::Ollama(model) => model.id(),
        }
    }
}

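/// A single message in a request to a language model.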
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: String,
}

impl LanguageModelRequestMessage {
    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
        proto::LanguageModelRequestMessage {
            role: match self.role {
                Role::User => proto::LanguageModelRole::LanguageModelUser,
                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
                Role::System => proto::LanguageModelRole::LanguageModelSystem,
            } as i32,
            content: self.content.clone(),
            tool_calls: Vec::new(),
            tool_call_id: None,
        }
    }
}

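/// A completion request: the model to use, the conversation so far, stop sequences, and the sampling temperature.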
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
    pub model: LanguageModel,
    pub messages: Vec<LanguageModelRequestMessage>,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl LanguageModelRequest {
    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
        proto::CompleteWithLanguageModel {
            model: self.model.id().to_string(),
            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
            stop: self.stop.clone(),
            temperature: self.temperature,
            tool_choice: None,
            tools: Vec::new(),
        }
    }

    /// Before we send the request to the server, we can perform fixups appropriate to the model.
    pub fn preprocess(&mut self) {
        match &self.model {
            LanguageModel::OpenAi(_) => {}
            LanguageModel::Anthropic(_) => {}
            LanguageModel::Ollama(_) => {}
            LanguageModel::Cloud(model) => match model {
                CloudModel::Claude3Opus
                | CloudModel::Claude3Sonnet
                | CloudModel::Claude3Haiku
                | CloudModel::Claude3_5Sonnet => {
                    preprocess_anthropic_request(self);
                }
                _ => {}
            },
        }
    }
}

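/// A message (or partial message) returned by a language model; both fields are optional because streamed deltas may carry only one of them.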
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}

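/// Token usage reported by the provider for a completion.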
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

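/// A streamed chunk of a completion choice: its index, the partial message delta, and an optional finish reason.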
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    pub delta: LanguageModelResponseMessage,
    pub finish_reason: Option<String>,
}

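/// Per-message bookkeeping: the author's role and the message's streaming status.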
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    role: Role,
    status: MessageStatus,
}

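/// The lifecycle of a message: pending while streaming, done on success, or an error.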
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    Pending,
    Done,
    Error(SharedString),
}

/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

impl Global for Assistant {}

impl Assistant {
    const NAMESPACE: &'static str = "assistant";

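    /// Enables or disables the assistant, showing or hiding its actions in the command palette accordingly.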
    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
        if self.enabled == enabled {
            return;
        }

        self.enabled = enabled;

        if !enabled {
            CommandPaletteFilter::update_global(cx, |filter, _cx| {
                filter.hide_namespace(Self::NAMESPACE);
            });

            return;
        }

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            filter.show_namespace(Self::NAMESPACE);
        });
    }
}

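/// Initializes the assistant: registers its settings and slash commands, spawns the
/// semantic index, sets up the panel and inline assistants, and keeps the command
/// palette filter in sync with the `enabled` setting.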
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    IndexedDocsRegistry::init_global(cx);

    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}

fn register_slash_commands(cx: &mut AppContext) {
    let slash_command_registry = SlashCommandRegistry::global(cx);
    slash_command_registry.register_command(file_command::FileSlashCommand, true);
    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
    slash_command_registry.register_command(term_command::TermSlashCommand, true);
    slash_command_registry.register_command(now_command::NowSlashCommand, true);
    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
    slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
}

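/// Formats a token count for display: counts below 1,000 are shown verbatim, counts below
/// 10,000 are rounded to the nearest hundred (e.g. 1,536 -> "1.5k"), and larger counts are
/// rounded to the nearest thousand (e.g. 10,500 -> "11k").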
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            let hundreds = (count % 1000 + 50) / 100;
            if hundreds == 0 {
                format!("{}k", thousands)
            } else if hundreds == 10 {
                format!("{}k", thousands + 1)
            } else {
                format!("{}.{}k", thousands, hundreds)
            }
        }
        _ => format!("{}k", (count + 500) / 1000),
    }
}
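
// A small sanity check of the rounding behavior documented above. These specific
// values are illustrative; they were not part of the original test suite.
#[cfg(test)]
mod humanize_token_count_tests {
    use super::*;

    #[test]
    fn rounds_token_counts_for_display() {
        assert_eq!(humanize_token_count(999), "999");
        assert_eq!(humanize_token_count(1000), "1k");
        assert_eq!(humanize_token_count(1536), "1.5k");
        assert_eq!(humanize_token_count(9950), "10k");
        assert_eq!(humanize_token_count(10500), "11k");
    }
}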

#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}