1mod ambient_context;
2pub mod assistant_panel;
3pub mod assistant_settings;
4mod codegen;
5mod completion_provider;
6mod prompt_library;
7mod prompts;
8mod saved_conversation;
9mod search;
10mod streaming_diff;
11
12use ambient_context::AmbientContextSnapshot;
13pub use assistant_panel::AssistantPanel;
14use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
15use client::{proto, Client};
16use command_palette_hooks::CommandPaletteFilter;
17pub(crate) use completion_provider::*;
18use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
19pub(crate) use saved_conversation::*;
20use serde::{Deserialize, Serialize};
21use settings::{Settings, SettingsStore};
22use std::{
23 fmt::{self, Display},
24 sync::Arc,
25};
26
// Registers the Assistant's actions. The first argument, `assistant`, is the
// action namespace; it is the same string as `Assistant::NAMESPACE`, which
// `Assistant::set_enabled` shows or hides in the command palette.
// NOTE(review): presumably gpui's `actions!` declares one unit action type per
// listed name — confirm against the gpui macro definition.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleIncludeConversation,
        ToggleHistory,
        ApplyEdit
    ]
);
43
44#[derive(
45 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
46)]
47struct MessageId(usize);
48
49#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
50#[serde(rename_all = "lowercase")]
51pub enum Role {
52 User,
53 Assistant,
54 System,
55}
56
57impl Role {
58 pub fn cycle(&mut self) {
59 *self = match self {
60 Role::User => Role::Assistant,
61 Role::Assistant => Role::System,
62 Role::System => Role::User,
63 }
64 }
65}
66
67impl Display for Role {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
69 match self {
70 Role::User => write!(f, "user"),
71 Role::Assistant => write!(f, "assistant"),
72 Role::System => write!(f, "system"),
73 }
74 }
75}
76
77#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
78pub enum LanguageModel {
79 ZedDotDev(ZedDotDevModel),
80 OpenAi(OpenAiModel),
81 Anthropic(AnthropicModel),
82}
83
84impl Default for LanguageModel {
85 fn default() -> Self {
86 LanguageModel::ZedDotDev(ZedDotDevModel::default())
87 }
88}
89
90impl LanguageModel {
91 pub fn telemetry_id(&self) -> String {
92 match self {
93 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
94 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
95 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
96 }
97 }
98
99 pub fn display_name(&self) -> String {
100 match self {
101 LanguageModel::OpenAi(model) => model.display_name().into(),
102 LanguageModel::Anthropic(model) => model.display_name().into(),
103 LanguageModel::ZedDotDev(model) => model.display_name().into(),
104 }
105 }
106
107 pub fn max_token_count(&self) -> usize {
108 match self {
109 LanguageModel::OpenAi(model) => model.max_token_count(),
110 LanguageModel::Anthropic(model) => model.max_token_count(),
111 LanguageModel::ZedDotDev(model) => model.max_token_count(),
112 }
113 }
114
115 pub fn id(&self) -> &str {
116 match self {
117 LanguageModel::OpenAi(model) => model.id(),
118 LanguageModel::Anthropic(model) => model.id(),
119 LanguageModel::ZedDotDev(model) => model.id(),
120 }
121 }
122}
123
124#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
125pub struct LanguageModelRequestMessage {
126 pub role: Role,
127 pub content: String,
128}
129
130impl LanguageModelRequestMessage {
131 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
132 proto::LanguageModelRequestMessage {
133 role: match self.role {
134 Role::User => proto::LanguageModelRole::LanguageModelUser,
135 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
136 Role::System => proto::LanguageModelRole::LanguageModelSystem,
137 } as i32,
138 content: self.content.clone(),
139 tool_calls: Vec::new(),
140 tool_call_id: None,
141 }
142 }
143}
144
145#[derive(Debug, Default, Serialize)]
146pub struct LanguageModelRequest {
147 pub model: LanguageModel,
148 pub messages: Vec<LanguageModelRequestMessage>,
149 pub stop: Vec<String>,
150 pub temperature: f32,
151}
152
153impl LanguageModelRequest {
154 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
155 proto::CompleteWithLanguageModel {
156 model: self.model.id().to_string(),
157 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
158 stop: self.stop.clone(),
159 temperature: self.temperature,
160 tool_choice: None,
161 tools: Vec::new(),
162 }
163 }
164}
165
166#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
167pub struct LanguageModelResponseMessage {
168 pub role: Option<Role>,
169 pub content: Option<String>,
170}
171
172#[derive(Deserialize, Debug)]
173pub struct LanguageModelUsage {
174 pub prompt_tokens: u32,
175 pub completion_tokens: u32,
176 pub total_tokens: u32,
177}
178
179#[derive(Deserialize, Debug)]
180pub struct LanguageModelChoiceDelta {
181 pub index: u32,
182 pub delta: LanguageModelResponseMessage,
183 pub finish_reason: Option<String>,
184}
185
186#[derive(Clone, Debug, Serialize, Deserialize)]
187struct MessageMetadata {
188 role: Role,
189 status: MessageStatus,
190 // todo!("delete this")
191 #[serde(skip)]
192 ambient_context: AmbientContextSnapshot,
193}
194
195#[derive(Clone, Debug, Serialize, Deserialize)]
196enum MessageStatus {
197 Pending,
198 Done,
199 Error(SharedString),
200}
201
202/// The state pertaining to the Assistant.
203#[derive(Default)]
204struct Assistant {
205 /// Whether the Assistant is enabled.
206 enabled: bool,
207}
208
209impl Global for Assistant {}
210
211impl Assistant {
212 const NAMESPACE: &'static str = "assistant";
213
214 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
215 if self.enabled == enabled {
216 return;
217 }
218
219 self.enabled = enabled;
220
221 if !enabled {
222 CommandPaletteFilter::update_global(cx, |filter, _cx| {
223 filter.hide_namespace(Self::NAMESPACE);
224 });
225
226 return;
227 }
228
229 CommandPaletteFilter::update_global(cx, |filter, _cx| {
230 filter.show_namespace(Self::NAMESPACE);
231 });
232 }
233}
234
235pub fn init(client: Arc<Client>, cx: &mut AppContext) {
236 cx.set_global(Assistant::default());
237 AssistantSettings::register(cx);
238 completion_provider::init(client, cx);
239 assistant_panel::init(cx);
240
241 CommandPaletteFilter::update_global(cx, |filter, _cx| {
242 filter.hide_namespace(Assistant::NAMESPACE);
243 });
244 Assistant::update_global(cx, |assistant, cx| {
245 let settings = AssistantSettings::get_global(cx);
246
247 assistant.set_enabled(settings.enabled, cx);
248 });
249 cx.observe_global::<SettingsStore>(|cx| {
250 Assistant::update_global(cx, |assistant, cx| {
251 let settings = AssistantSettings::get_global(cx);
252
253 assistant.set_enabled(settings.enabled, cx);
254 });
255 })
256 .detach();
257}
258
/// Test-only constructor (registered via `ctor`) that initializes `env_logger`
/// when `RUST_LOG` is set to a valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}