use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use language_model::LanguageModelCompletionEvent;
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, KeepAlive, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::AllLanguageModelSettings;

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

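/// Settings for the Ollama provider: the base URL of the local Ollama server
/// and any extra models the user has configured by hand, beyond those the
/// running server reports via `get_models`.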
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

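/// A model made available through settings rather than discovered from the
/// running server. As a rough sketch, an entry built in code would look like
/// the following (the values are illustrative, not defaults):
///
/// ```ignore
/// AvailableModel {
///     name: "llama3.2:latest".into(),
///     display_name: Some("Llama 3.2".into()),
///     max_tokens: 8192,
///     keep_alive: None,
/// }
/// ```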
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.2:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The context length to request from the model (aka `num_ctx` or `n_ctx`)
    pub max_tokens: usize,
    /// How long the model should stay loaded after the last request (Ollama's `keep_alive`)
    pub keep_alive: Option<KeepAlive>,
}

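/// Language model provider backed by a locally running Ollama server.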
pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

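/// Shared provider state: the models discovered from the running server.
/// Ollama has no credentials, so a non-empty model list doubles as the
/// provider's "authenticated" signal.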
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

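    /// Fetches the models the local server reports, drops apparent embedding
    /// models, and stores the rest on this `State`, notifying observers.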
    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check whether it's up by fetching the models
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut ModelContext<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

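    /// Merges the models reported by the Ollama server with any models from
    /// settings; a settings entry overrides the server's entry when the names
    /// collide.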
    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
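    /// Converts Zed's provider-agnostic request into an Ollama chat request,
    /// flattening each message into plain text and always asking for a
    /// streamed response.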
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            tools: vec![],
        }
    }

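    /// Issues a single (non-streaming) chat completion against the configured
    /// API URL. Used by `use_any_tool`, which needs the whole response in
    /// order to inspect its tool calls.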
    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

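    /// Ollama exposes no tokenizer endpoint, so this falls back to a rough
    /// character-count heuristic of about four characters per token; e.g. a
    /// 400-character conversation is estimated at roughly 100 tokens.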
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

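    /// Runs the request once with a single function tool attached and, if the
    /// assistant message calls that tool, returns its arguments as a one-item
    /// stream of JSON text.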
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

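/// The provider's configuration UI in the assistant panel: kicks off a model
/// fetch when created and renders either install instructions with a Connect
/// button or a "Connected" indicator.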
struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.3, Mistral, Gemma 2, and other large language models with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.2")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is rendered as a button only to keep the spacing consistent
                            // with the Connect button; it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}