1pub mod assistant_panel;
2pub mod assistant_settings;
3mod codegen;
4mod completion_provider;
5mod model_selector;
6mod prompt_library;
7mod prompts;
8mod saved_conversation;
9mod search;
10mod slash_command;
11mod streaming_diff;
12
13pub use assistant_panel::AssistantPanel;
14
15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel};
16use assistant_slash_command::SlashCommandRegistry;
17use client::{proto, Client};
18use command_palette_hooks::CommandPaletteFilter;
19pub(crate) use completion_provider::*;
20use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
21pub(crate) use model_selector::*;
22pub(crate) use saved_conversation::*;
23use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
24use serde::{Deserialize, Serialize};
25use settings::{Settings, SettingsStore};
26use slash_command::{
27 active_command, default_command, fetch_command, file_command, project_command, prompt_command,
28 rustdoc_command, search_command, tabs_command,
29};
30use std::{
31 fmt::{self, Display},
32 sync::Arc,
33};
34use util::paths::EMBEDDINGS_DIR;
35
// Actions in the `assistant` namespace. This is the same string as
// `Assistant::NAMESPACE`, which `Assistant::set_enabled` hides or shows via
// `CommandPaletteFilter`.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
53
54#[derive(
55 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
56)]
57struct MessageId(usize);
58
59#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
60#[serde(rename_all = "lowercase")]
61pub enum Role {
62 User,
63 Assistant,
64 System,
65}
66
67impl Role {
68 pub fn cycle(&mut self) {
69 *self = match self {
70 Role::User => Role::Assistant,
71 Role::Assistant => Role::System,
72 Role::System => Role::User,
73 }
74 }
75}
76
77impl Display for Role {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
79 match self {
80 Role::User => write!(f, "user"),
81 Role::Assistant => write!(f, "assistant"),
82 Role::System => write!(f, "system"),
83 }
84 }
85}
86
87#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
88pub enum LanguageModel {
89 Cloud(CloudModel),
90 OpenAi(OpenAiModel),
91 Anthropic(AnthropicModel),
92}
93
94impl Default for LanguageModel {
95 fn default() -> Self {
96 LanguageModel::Cloud(CloudModel::default())
97 }
98}
99
100impl LanguageModel {
101 pub fn telemetry_id(&self) -> String {
102 match self {
103 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
104 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
105 LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
106 }
107 }
108
109 pub fn display_name(&self) -> String {
110 match self {
111 LanguageModel::OpenAi(model) => model.display_name().into(),
112 LanguageModel::Anthropic(model) => model.display_name().into(),
113 LanguageModel::Cloud(model) => model.display_name().into(),
114 }
115 }
116
117 pub fn max_token_count(&self) -> usize {
118 match self {
119 LanguageModel::OpenAi(model) => model.max_token_count(),
120 LanguageModel::Anthropic(model) => model.max_token_count(),
121 LanguageModel::Cloud(model) => model.max_token_count(),
122 }
123 }
124
125 pub fn id(&self) -> &str {
126 match self {
127 LanguageModel::OpenAi(model) => model.id(),
128 LanguageModel::Anthropic(model) => model.id(),
129 LanguageModel::Cloud(model) => model.id(),
130 }
131 }
132}
133
134#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
135pub struct LanguageModelRequestMessage {
136 pub role: Role,
137 pub content: String,
138}
139
140impl LanguageModelRequestMessage {
141 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
142 proto::LanguageModelRequestMessage {
143 role: match self.role {
144 Role::User => proto::LanguageModelRole::LanguageModelUser,
145 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
146 Role::System => proto::LanguageModelRole::LanguageModelSystem,
147 } as i32,
148 content: self.content.clone(),
149 tool_calls: Vec::new(),
150 tool_call_id: None,
151 }
152 }
153}
154
155#[derive(Debug, Default, Serialize)]
156pub struct LanguageModelRequest {
157 pub model: LanguageModel,
158 pub messages: Vec<LanguageModelRequestMessage>,
159 pub stop: Vec<String>,
160 pub temperature: f32,
161}
162
163impl LanguageModelRequest {
164 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
165 proto::CompleteWithLanguageModel {
166 model: self.model.id().to_string(),
167 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
168 stop: self.stop.clone(),
169 temperature: self.temperature,
170 tool_choice: None,
171 tools: Vec::new(),
172 }
173 }
174
175 /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
176 pub fn preprocess(&mut self) {
177 match &self.model {
178 LanguageModel::OpenAi(_) => {}
179 LanguageModel::Anthropic(_) => {}
180 LanguageModel::Cloud(model) => match model {
181 CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
182 preprocess_anthropic_request(self);
183 }
184 _ => {}
185 },
186 }
187 }
188}
189
190#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
191pub struct LanguageModelResponseMessage {
192 pub role: Option<Role>,
193 pub content: Option<String>,
194}
195
196#[derive(Deserialize, Debug)]
197pub struct LanguageModelUsage {
198 pub prompt_tokens: u32,
199 pub completion_tokens: u32,
200 pub total_tokens: u32,
201}
202
203#[derive(Deserialize, Debug)]
204pub struct LanguageModelChoiceDelta {
205 pub index: u32,
206 pub delta: LanguageModelResponseMessage,
207 pub finish_reason: Option<String>,
208}
209
210#[derive(Clone, Debug, Serialize, Deserialize)]
211struct MessageMetadata {
212 role: Role,
213 status: MessageStatus,
214}
215
216#[derive(Clone, Debug, Serialize, Deserialize)]
217enum MessageStatus {
218 Pending,
219 Done,
220 Error(SharedString),
221}
222
223/// The state pertaining to the Assistant.
224#[derive(Default)]
225struct Assistant {
226 /// Whether the Assistant is enabled.
227 enabled: bool,
228}
229
230impl Global for Assistant {}
231
232impl Assistant {
233 const NAMESPACE: &'static str = "assistant";
234
235 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
236 if self.enabled == enabled {
237 return;
238 }
239
240 self.enabled = enabled;
241
242 if !enabled {
243 CommandPaletteFilter::update_global(cx, |filter, _cx| {
244 filter.hide_namespace(Self::NAMESPACE);
245 });
246
247 return;
248 }
249
250 CommandPaletteFilter::update_global(cx, |filter, _cx| {
251 filter.show_namespace(Self::NAMESPACE);
252 });
253 }
254}
255
256pub fn init(client: Arc<Client>, cx: &mut AppContext) {
257 cx.set_global(Assistant::default());
258 AssistantSettings::register(cx);
259
260 cx.spawn(|mut cx| {
261 let client = client.clone();
262 async move {
263 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
264 let semantic_index = SemanticIndex::new(
265 EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
266 Arc::new(embedding_provider),
267 &mut cx,
268 )
269 .await?;
270 cx.update(|cx| cx.set_global(semantic_index))
271 }
272 })
273 .detach();
274
275 prompt_library::init(cx);
276 completion_provider::init(client, cx);
277 assistant_slash_command::init(cx);
278 register_slash_commands(cx);
279 assistant_panel::init(cx);
280
281 CommandPaletteFilter::update_global(cx, |filter, _cx| {
282 filter.hide_namespace(Assistant::NAMESPACE);
283 });
284 Assistant::update_global(cx, |assistant, cx| {
285 let settings = AssistantSettings::get_global(cx);
286
287 assistant.set_enabled(settings.enabled, cx);
288 });
289 cx.observe_global::<SettingsStore>(|cx| {
290 Assistant::update_global(cx, |assistant, cx| {
291 let settings = AssistantSettings::get_global(cx);
292 assistant.set_enabled(settings.enabled, cx);
293 });
294 })
295 .detach();
296}
297
298fn register_slash_commands(cx: &mut AppContext) {
299 let slash_command_registry = SlashCommandRegistry::global(cx);
300 slash_command_registry.register_command(file_command::FileSlashCommand, true);
301 slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
302 slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
303 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
304 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
305 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
306 slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
307 slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
308 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
309}
310
/// Test-only constructor: initializes `env_logger` before tests run, but only
/// when `RUST_LOG` is set to a valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}