1pub mod assistant_panel;
2pub mod assistant_settings;
3mod completion_provider;
4mod context;
5pub mod context_store;
6mod inline_assistant;
7mod model_selector;
8mod prompt_library;
9mod prompts;
10mod slash_command;
11mod streaming_diff;
12mod terminal_inline_assistant;
13
14pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
16use assistant_slash_command::SlashCommandRegistry;
17use client::{proto, Client};
18use command_palette_hooks::CommandPaletteFilter;
19pub use completion_provider::*;
20pub use context::*;
21pub use context_store::*;
22use fs::Fs;
23use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
24use indexed_docs::IndexedDocsRegistry;
25pub(crate) use inline_assistant::*;
26pub(crate) use model_selector::*;
27use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
28use serde::{Deserialize, Serialize};
29use settings::{Settings, SettingsStore};
30use slash_command::{
31 active_command, default_command, diagnostics_command, docs_command, fetch_command,
32 file_command, now_command, project_command, prompt_command, search_command, symbols_command,
33 tabs_command, term_command,
34};
35use std::{
36 fmt::{self, Display},
37 sync::Arc,
38};
39pub(crate) use streaming_diff::*;
40
// Unit-struct actions in the `assistant` namespace, dispatchable from the
// command palette and bindable in keymaps.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        ResetKey,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
        DebugEditSteps
    ]
);
59
/// Action that triggers an inline assist, carrying data (unlike the unit
/// actions above), hence the separate `impl_actions!` registration.
#[derive(Clone, Default, Deserialize, PartialEq)]
pub struct InlineAssist {
    // Optional prompt text supplied with the action; consumers of the action
    // presumably treat `None` as "start with an empty prompt" — not visible here.
    prompt: Option<String>,
}

impl_actions!(assistant, [InlineAssist]);
66
/// Identifier for a conversation message, backed by a `clock::Lamport`
/// timestamp.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct MessageId(clock::Lamport);

impl MessageId {
    /// Returns the id as a plain `u64` (delegates to `clock::Lamport::as_u64`).
    pub fn as_u64(self) -> u64 {
        self.0.as_u64()
    }
}
75
/// The author of a message in a model conversation.
/// Serialized in lowercase (`"user"`, `"assistant"`, `"system"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
83
84impl Role {
85 pub fn from_proto(role: i32) -> Role {
86 match proto::LanguageModelRole::from_i32(role) {
87 Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
88 Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
89 Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
90 Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
91 None => Role::User,
92 }
93 }
94
95 pub fn to_proto(&self) -> proto::LanguageModelRole {
96 match self {
97 Role::User => proto::LanguageModelRole::LanguageModelUser,
98 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
99 Role::System => proto::LanguageModelRole::LanguageModelSystem,
100 }
101 }
102
103 pub fn cycle(self) -> Role {
104 match self {
105 Role::User => Role::Assistant,
106 Role::Assistant => Role::System,
107 Role::System => Role::User,
108 }
109 }
110}
111
112impl Display for Role {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
114 match self {
115 Role::User => write!(f, "user"),
116 Role::Assistant => write!(f, "assistant"),
117 Role::System => write!(f, "system"),
118 }
119 }
120}
121
/// A language model selectable in the assistant, tagged by provider.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// Zed-hosted model.
    Cloud(CloudModel),
    OpenAi(OpenAiModel),
    Anthropic(AnthropicModel),
    Ollama(OllamaModel),
}
129
130impl Default for LanguageModel {
131 fn default() -> Self {
132 LanguageModel::Cloud(CloudModel::default())
133 }
134}
135
136impl LanguageModel {
137 pub fn telemetry_id(&self) -> String {
138 match self {
139 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
140 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
141 LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
142 LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
143 }
144 }
145
146 pub fn display_name(&self) -> String {
147 match self {
148 LanguageModel::OpenAi(model) => model.display_name().into(),
149 LanguageModel::Anthropic(model) => model.display_name().into(),
150 LanguageModel::Cloud(model) => model.display_name().into(),
151 LanguageModel::Ollama(model) => model.display_name().into(),
152 }
153 }
154
155 pub fn max_token_count(&self) -> usize {
156 match self {
157 LanguageModel::OpenAi(model) => model.max_token_count(),
158 LanguageModel::Anthropic(model) => model.max_token_count(),
159 LanguageModel::Cloud(model) => model.max_token_count(),
160 LanguageModel::Ollama(model) => model.max_token_count(),
161 }
162 }
163
164 pub fn id(&self) -> &str {
165 match self {
166 LanguageModel::OpenAi(model) => model.id(),
167 LanguageModel::Anthropic(model) => model.id(),
168 LanguageModel::Cloud(model) => model.id(),
169 LanguageModel::Ollama(model) => model.id(),
170 }
171 }
172}
173
/// A single message included in a completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: String,
}
179
180impl LanguageModelRequestMessage {
181 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
182 proto::LanguageModelRequestMessage {
183 role: self.role.to_proto() as i32,
184 content: self.content.clone(),
185 tool_calls: Vec::new(),
186 tool_call_id: None,
187 }
188 }
189}
190
/// A completion request: the target model plus the conversation so far.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
    pub model: LanguageModel,
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Stop sequences that end generation when produced.
    pub stop: Vec<String>,
    pub temperature: f32,
}
198
199impl LanguageModelRequest {
200 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
201 proto::CompleteWithLanguageModel {
202 model: self.model.id().to_string(),
203 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
204 stop: self.stop.clone(),
205 temperature: self.temperature,
206 tool_choice: None,
207 tools: Vec::new(),
208 }
209 }
210
211 /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
212 pub fn preprocess(&mut self) {
213 match &self.model {
214 LanguageModel::OpenAi(_) => {}
215 LanguageModel::Anthropic(_) => {}
216 LanguageModel::Ollama(_) => {}
217 LanguageModel::Cloud(model) => match model {
218 CloudModel::Claude3Opus
219 | CloudModel::Claude3Sonnet
220 | CloudModel::Claude3Haiku
221 | CloudModel::Claude3_5Sonnet => {
222 preprocess_anthropic_request(self);
223 }
224 _ => {}
225 },
226 }
227 }
228}
229
/// A message fragment in a model's response; both fields are optional
/// because streamed chunks may carry only one of them.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
235
/// Token-usage accounting reported for a completion.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
242
/// One streamed delta for a completion choice.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// Incremental message payload.
    pub delta: LanguageModelResponseMessage,
    /// Reason the stream finished, when provided.
    pub finish_reason: Option<String>,
}
249
/// Completion status of a message in a context.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MessageStatus {
    Pending,
    Done,
    Error(SharedString),
}
256
257impl MessageStatus {
258 pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
259 match status.variant {
260 Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
261 Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
262 Some(proto::context_message_status::Variant::Error(error)) => {
263 MessageStatus::Error(error.message.into())
264 }
265 None => MessageStatus::Pending,
266 }
267 }
268
269 pub fn to_proto(&self) -> proto::ContextMessageStatus {
270 match self {
271 MessageStatus::Pending => proto::ContextMessageStatus {
272 variant: Some(proto::context_message_status::Variant::Pending(
273 proto::context_message_status::Pending {},
274 )),
275 },
276 MessageStatus::Done => proto::ContextMessageStatus {
277 variant: Some(proto::context_message_status::Variant::Done(
278 proto::context_message_status::Done {},
279 )),
280 },
281 MessageStatus::Error(message) => proto::ContextMessageStatus {
282 variant: Some(proto::context_message_status::Variant::Error(
283 proto::context_message_status::Error {
284 message: message.to_string(),
285 },
286 )),
287 },
288 }
289 }
290}
291
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

// Registered as a gpui global so it can be read and updated from any context.
impl Global for Assistant {}
300
301impl Assistant {
302 const NAMESPACE: &'static str = "assistant";
303
304 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
305 if self.enabled == enabled {
306 return;
307 }
308
309 self.enabled = enabled;
310
311 if !enabled {
312 CommandPaletteFilter::update_global(cx, |filter, _cx| {
313 filter.hide_namespace(Self::NAMESPACE);
314 });
315
316 return;
317 }
318
319 CommandPaletteFilter::update_global(cx, |filter, _cx| {
320 filter.show_namespace(Self::NAMESPACE);
321 });
322 }
323}
324
/// Initializes the assistant subsystem: registers settings and globals,
/// kicks off the semantic index, registers slash commands and panels, and
/// wires the command-palette filter to the `assistant.enabled` setting.
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index in the background; it becomes a global once
    // its database is opened.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    context_store::init(&client);
    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    IndexedDocsRegistry::init_global(cx);

    // Start with the namespace hidden; `set_enabled` below reveals it when
    // the setting is on.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // Keep the enabled state in sync with subsequent settings changes.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
370
/// Registers all built-in slash commands with the global registry.
// NOTE(review): the boolean argument's meaning isn't visible from this file —
// every command passes `true` except `fetch`; confirm against
// `SlashCommandRegistry::register_command`.
fn register_slash_commands(cx: &mut AppContext) {
    let slash_command_registry = SlashCommandRegistry::global(cx);
    slash_command_registry.register_command(file_command::FileSlashCommand, true);
    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
    slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
    slash_command_registry.register_command(term_command::TermSlashCommand, true);
    slash_command_registry.register_command(now_command::NowSlashCommand, true);
    slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
    slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
}
387
/// Formats a token count for display: exact below 1000 ("999"), rounded to
/// tenths of a thousand up to 9999 ("1.5k", rolling over to "10k"), and
/// rounded to whole thousands beyond that ("123k").
pub fn humanize_token_count(count: usize) -> String {
    if count < 1000 {
        return count.to_string();
    }
    if count < 10_000 {
        // Round to the nearest tenth of a thousand; e.g. 1050 -> 11 tenths.
        let tenths = (count + 50) / 100;
        let (whole, frac) = (tenths / 10, tenths % 10);
        return if frac == 0 {
            format!("{}k", whole)
        } else {
            format!("{}.{}k", whole, frac)
        };
    }
    // 10k and up: round to the nearest thousand.
    format!("{}k", (count + 500) / 1000)
}
405
// Test-only hook: `ctor` runs this before the test harness starts, enabling
// env_logger output when RUST_LOG is set in the environment.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}