pub mod assistant_panel;
pub mod assistant_settings;
mod completion_provider;
mod context_store;
mod inline_assistant;
mod model_selector;
mod prompt_library;
mod prompts;
mod search;
mod slash_command;
mod streaming_diff;
mod terminal_inline_assistant;

pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub(crate) use completion_provider::*;
pub(crate) use context_store::*;
use fs::Fs;
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*;
pub(crate) use model_selector::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use slash_command::{
    active_command, default_command, diagnostics_command, fetch_command, file_command, now_command,
    project_command, prompt_command, rustdoc_command, search_command, tabs_command, term_command,
};
use std::{
    fmt::{self, Display},
    sync::Arc,
};
pub(crate) use streaming_diff::*;

actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);

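/// Identifies a single message within an assistant conversation.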
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);

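/// The role of a message's author: user, assistant, or system.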
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl Role {
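    /// Advances this role to the next one in the fixed User → Assistant → System → User cycle.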
    pub fn cycle(&mut self) {
        *self = match self {
            Role::User => Role::Assistant,
            Role::Assistant => Role::System,
            Role::System => Role::User,
        }
    }
}

impl Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Role::User => write!(f, "user"),
            Role::Assistant => write!(f, "assistant"),
            Role::System => write!(f, "system"),
        }
    }
}

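/// A language model available to the assistant, tagged by provider: Zed's cloud service, OpenAI,
/// Anthropic, or Ollama.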
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    Cloud(CloudModel),
    OpenAi(OpenAiModel),
    Anthropic(AnthropicModel),
    Ollama(OllamaModel),
}

impl Default for LanguageModel {
    fn default() -> Self {
        LanguageModel::Cloud(CloudModel::default())
    }
}

impl LanguageModel {
    pub fn telemetry_id(&self) -> String {
        match self {
            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
        }
    }

    pub fn display_name(&self) -> String {
        match self {
            LanguageModel::OpenAi(model) => model.display_name().into(),
            LanguageModel::Anthropic(model) => model.display_name().into(),
            LanguageModel::Cloud(model) => model.display_name().into(),
            LanguageModel::Ollama(model) => model.display_name().into(),
        }
    }

    pub fn max_token_count(&self) -> usize {
        match self {
            LanguageModel::OpenAi(model) => model.max_token_count(),
            LanguageModel::Anthropic(model) => model.max_token_count(),
            LanguageModel::Cloud(model) => model.max_token_count(),
            LanguageModel::Ollama(model) => model.max_token_count(),
        }
    }

    pub fn id(&self) -> &str {
        match self {
            LanguageModel::OpenAi(model) => model.id(),
            LanguageModel::Anthropic(model) => model.id(),
            LanguageModel::Cloud(model) => model.id(),
            LanguageModel::Ollama(model) => model.id(),
        }
    }
}

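/// A single message in a completion request, convertible to its RPC protobuf representation via
/// [`Self::to_proto`].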
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: String,
}

impl LanguageModelRequestMessage {
    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
        proto::LanguageModelRequestMessage {
            role: match self.role {
                Role::User => proto::LanguageModelRole::LanguageModelUser,
                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
                Role::System => proto::LanguageModelRole::LanguageModelSystem,
            } as i32,
            content: self.content.clone(),
            tool_calls: Vec::new(),
            tool_call_id: None,
        }
    }
}

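/// A provider-agnostic completion request: the target model, the conversation so far, stop
/// sequences, and the sampling temperature.
///
/// A minimal construction sketch (illustrative values only, not taken from real defaults):
///
/// ```ignore
/// let request = LanguageModelRequest {
///     model: LanguageModel::default(),
///     messages: vec![LanguageModelRequestMessage {
///         role: Role::User,
///         content: "Summarize this file.".into(),
///     }],
///     stop: Vec::new(),
///     temperature: 1.0,
/// };
/// ```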
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    pub model: LanguageModel,
    pub messages: Vec<LanguageModelRequestMessage>,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl LanguageModelRequest {
    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
        proto::CompleteWithLanguageModel {
            model: self.model.id().to_string(),
            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
            stop: self.stop.clone(),
            temperature: self.temperature,
            tool_choice: None,
            tools: Vec::new(),
        }
    }

    /// Performs any fixups appropriate to the target model before the request is sent to the server.
    pub fn preprocess(&mut self) {
        match &self.model {
            LanguageModel::OpenAi(_) => {}
            LanguageModel::Anthropic(_) => {}
            LanguageModel::Ollama(_) => {}
            LanguageModel::Cloud(model) => match model {
                CloudModel::Claude3Opus
                | CloudModel::Claude3Sonnet
                | CloudModel::Claude3Haiku
                | CloudModel::Claude3_5Sonnet => {
                    preprocess_anthropic_request(self);
                }
                _ => {}
            },
        }
    }
}

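/// A message (or streamed fragment of one) returned by a model; both fields are optional because
/// streamed deltas may carry only part of a message.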
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}

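/// Token accounting reported by the provider for a completion.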
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    pub delta: LanguageModelResponseMessage,
    pub finish_reason: Option<String>,
}

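/// Bookkeeping stored alongside each message: who authored it and whether its completion is
/// pending, done, or failed.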
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    role: Role,
    status: MessageStatus,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    Pending,
    Done,
    Error(SharedString),
}

/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

impl Global for Assistant {}

impl Assistant {
    const NAMESPACE: &'static str = "assistant";

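    /// Enables or disables the assistant, showing or hiding its actions in the command palette
    /// accordingly.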
    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
        if self.enabled == enabled {
            return;
        }

        self.enabled = enabled;

        if !enabled {
            CommandPaletteFilter::update_global(cx, |filter, _cx| {
                filter.hide_namespace(Self::NAMESPACE);
            });

            return;
        }

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            filter.show_namespace(Self::NAMESPACE);
        });
    }
}

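/// Initializes the assistant subsystems: settings, the semantic index, the completion provider,
/// slash commands, the assistant panel, and the editor/terminal inline assistants, then keeps the
/// command palette filter in sync with the `enabled` setting.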
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    IndexedDocsRegistry::init_global(cx);

    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);
        assistant.set_enabled(settings.enabled, cx);
    });
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}

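/// Registers the built-in slash commands with the global [`SlashCommandRegistry`].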
fn register_slash_commands(cx: &mut AppContext) {
    let slash_command_registry = SlashCommandRegistry::global(cx);
    slash_command_registry.register_command(file_command::FileSlashCommand, true);
    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
    slash_command_registry.register_command(term_command::TermSlashCommand, true);
    slash_command_registry.register_command(now_command::NowSlashCommand, true);
    slash_command_registry.register_command(diagnostics_command::DiagnosticsCommand, true);
    slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
}

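/// Formats a token count for display: counts under 1,000 are shown verbatim, counts under 10,000
/// are rounded to a tenth of a thousand (e.g. 1,050 -> "1.1k"), and larger counts are rounded to
/// the nearest thousand (e.g. 10,500 -> "11k").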
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            let hundreds = (count % 1000 + 50) / 100;
            if hundreds == 0 {
                format!("{}k", thousands)
            } else if hundreds == 10 {
                format!("{}k", thousands + 1)
            } else {
                format!("{}.{}k", thousands, hundreds)
            }
        }
        _ => format!("{}k", (count + 500) / 1000),
    }
}
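
#[cfg(test)]
mod humanize_token_count_tests {
    use super::*;

    // A small sketch of the rounding behavior, derived from the arithmetic above: sub-1k counts
    // print verbatim, sub-10k counts round to one decimal place (carrying into the next thousand
    // at x950 and above), and larger counts round to the nearest whole thousand. The module and
    // test names here are illustrative, not taken from the original codebase.
    #[test]
    fn rounds_token_counts_for_display() {
        assert_eq!(humanize_token_count(999), "999");
        assert_eq!(humanize_token_count(1000), "1k");
        assert_eq!(humanize_token_count(1049), "1k");
        assert_eq!(humanize_token_count(1050), "1.1k");
        assert_eq!(humanize_token_count(1950), "2k");
        assert_eq!(humanize_token_count(10499), "10k");
        assert_eq!(humanize_token_count(10500), "11k");
    }
}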

#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}