app.go

  1package app
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"maps"
 10	"sync"
 11	"time"
 12
 13	tea "github.com/charmbracelet/bubbletea/v2"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/csync"
 16	"github.com/charmbracelet/crush/internal/db"
 17	"github.com/charmbracelet/crush/internal/format"
 18	"github.com/charmbracelet/crush/internal/history"
 19	"github.com/charmbracelet/crush/internal/llm/agent"
 20	"github.com/charmbracelet/crush/internal/log"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22
 23	"github.com/charmbracelet/crush/internal/lsp"
 24	"github.com/charmbracelet/crush/internal/lsp/watcher"
 25	"github.com/charmbracelet/crush/internal/message"
 26	"github.com/charmbracelet/crush/internal/permission"
 27	"github.com/charmbracelet/crush/internal/session"
 28)
 29
// App wires together the application's services (sessions, messages, file
// history, permissions), the coder agent, the LSP clients, and the event
// plumbing that forwards service events to the TUI.
type App struct {
	Sessions    session.Service
	Messages    message.Service
	History     history.Service
	Permissions permission.Service

	// CoderAgent is set by InitCoderAgent. New only calls that when the
	// configuration reports IsConfigured, so this may be nil.
	CoderAgent agent.Service

	// LSPClients maps a client name to its LSP client.
	LSPClients map[string]*lsp.Client

	// clientsMutex guards LSPClients (Shutdown takes the read lock to
	// snapshot the map).
	clientsMutex sync.RWMutex

	// watcherCancelFuncs holds cancel functions for LSP watchers. Shutdown
	// calls each of them and then waits on lspWatcherWG.
	watcherCancelFuncs *csync.Slice[context.CancelFunc]
	lspWatcherWG       sync.WaitGroup

	config *config.Config

	// serviceEventsWG tracks the forwarding goroutines started through
	// setupSubscriber; eventsCtx is the context that cancels them. events is
	// the buffered channel those goroutines write to and Subscribe drains.
	serviceEventsWG *sync.WaitGroup
	eventsCtx       context.Context
	events          chan tea.Msg
	// tuiWG tracks the running Subscribe loop so its cleanup can wait for it.
	tuiWG *sync.WaitGroup

	// global context and cleanup functions; Shutdown runs cleanupFuncs in
	// registration order.
	globalCtx    context.Context
	cleanupFuncs []func() error
}
 56
 57// New initializes a new applcation instance.
 58func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
 59	q := db.New(conn)
 60	sessions := session.NewService(q)
 61	messages := message.NewService(q)
 62	files := history.NewService(q, conn)
 63	skipPermissionsRequests := cfg.Permissions != nil && cfg.Permissions.SkipRequests
 64	allowedTools := []string{}
 65	if cfg.Permissions != nil && cfg.Permissions.AllowedTools != nil {
 66		allowedTools = cfg.Permissions.AllowedTools
 67	}
 68
 69	app := &App{
 70		Sessions:    sessions,
 71		Messages:    messages,
 72		History:     files,
 73		Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools),
 74		LSPClients:  make(map[string]*lsp.Client),
 75
 76		globalCtx: ctx,
 77
 78		config: cfg,
 79
 80		watcherCancelFuncs: csync.NewSlice[context.CancelFunc](),
 81
 82		events:          make(chan tea.Msg, 100),
 83		serviceEventsWG: &sync.WaitGroup{},
 84		tuiWG:           &sync.WaitGroup{},
 85	}
 86
 87	app.setupEvents()
 88
 89	// Start the global watcher
 90	if err := watcher.Start(); err != nil {
 91		return nil, fmt.Errorf("app: %w", err)
 92	}
 93
 94	// Initialize LSP clients in the background.
 95	app.initLSPClients(ctx)
 96
 97	// cleanup database upon app shutdown
 98	app.cleanupFuncs = append(app.cleanupFuncs, conn.Close)
 99
100	// TODO: remove the concept of agent config, most likely.
101	if cfg.IsConfigured() {
102		if err := app.InitCoderAgent(); err != nil {
103			return nil, fmt.Errorf("failed to initialize coder agent: %w", err)
104		}
105	} else {
106		slog.Warn("No agent configuration found")
107	}
108	return app, nil
109}
110
111// Config returns the application configuration.
112func (app *App) Config() *config.Config {
113	return app.config
114}
115
116// RunNonInteractive handles the execution flow when a prompt is provided via
117// CLI flag.
118func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool) error {
119	slog.Info("Running in non-interactive mode")
120
121	ctx, cancel := context.WithCancel(ctx)
122	defer cancel()
123
124	// Start spinner if not in quiet mode.
125	var spinner *format.Spinner
126	if !quiet {
127		spinner = format.NewSpinner(ctx, cancel, "Generating")
128		spinner.Start()
129	}
130
131	// Helper function to stop spinner once.
132	stopSpinner := func() {
133		if !quiet && spinner != nil {
134			spinner.Stop()
135			spinner = nil
136		}
137	}
138	defer stopSpinner()
139
140	const maxPromptLengthForTitle = 100
141	titlePrefix := "Non-interactive: "
142	var titleSuffix string
143
144	if len(prompt) > maxPromptLengthForTitle {
145		titleSuffix = prompt[:maxPromptLengthForTitle] + "..."
146	} else {
147		titleSuffix = prompt
148	}
149	title := titlePrefix + titleSuffix
150
151	sess, err := app.Sessions.Create(ctx, title)
152	if err != nil {
153		return fmt.Errorf("failed to create session for non-interactive mode: %w", err)
154	}
155	slog.Info("Created session for non-interactive run", "session_id", sess.ID)
156
157	// Automatically approve all permission requests for this non-interactive session
158	app.Permissions.AutoApproveSession(sess.ID)
159
160	done, err := app.CoderAgent.Run(ctx, sess.ID, prompt)
161	if err != nil {
162		return fmt.Errorf("failed to start agent processing stream: %w", err)
163	}
164
165	messageEvents := app.Messages.Subscribe(ctx)
166	messageReadBytes := make(map[string]int)
167
168	for {
169		select {
170		case result := <-done:
171			stopSpinner()
172
173			if result.Error != nil {
174				if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) {
175					slog.Info("Non-interactive: agent processing cancelled", "session_id", sess.ID)
176					return nil
177				}
178				return fmt.Errorf("agent processing failed: %w", result.Error)
179			}
180
181			msgContent := result.Message.Content().String()
182			readBts := messageReadBytes[result.Message.ID]
183
184			if len(msgContent) < readBts {
185				slog.Error("Non-interactive: message content is shorter than read bytes", "message_length", len(msgContent), "read_bytes", readBts)
186				return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(msgContent), readBts)
187			}
188			fmt.Println(msgContent[readBts:])
189			messageReadBytes[result.Message.ID] = len(msgContent)
190
191			slog.Info("Non-interactive: run completed", "session_id", sess.ID)
192			return nil
193
194		case event := <-messageEvents:
195			msg := event.Payload
196			if msg.SessionID == sess.ID && msg.Role == message.Assistant && len(msg.Parts) > 0 {
197				stopSpinner()
198
199				content := msg.Content().String()
200				readBytes := messageReadBytes[msg.ID]
201
202				if len(content) < readBytes {
203					slog.Error("Non-interactive: message content is shorter than read bytes", "message_length", len(content), "read_bytes", readBytes)
204					return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
205				}
206
207				part := content[readBytes:]
208				fmt.Print(part)
209				messageReadBytes[msg.ID] = len(content)
210			}
211
212		case <-ctx.Done():
213			stopSpinner()
214			return ctx.Err()
215		}
216	}
217}
218
219func (app *App) UpdateAgentModel() error {
220	return app.CoderAgent.UpdateModel()
221}
222
223func (app *App) setupEvents() {
224	ctx, cancel := context.WithCancel(app.globalCtx)
225	app.eventsCtx = ctx
226	setupSubscriber(ctx, app.serviceEventsWG, "sessions", app.Sessions.Subscribe, app.events)
227	setupSubscriber(ctx, app.serviceEventsWG, "messages", app.Messages.Subscribe, app.events)
228	setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events)
229	setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events)
230	setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events)
231	setupSubscriber(ctx, app.serviceEventsWG, "mcp", agent.SubscribeMCPEvents, app.events)
232	setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events)
233	cleanupFunc := func() error {
234		cancel()
235		app.serviceEventsWG.Wait()
236		return nil
237	}
238	app.cleanupFuncs = append(app.cleanupFuncs, cleanupFunc)
239}
240
241func setupSubscriber[T any](
242	ctx context.Context,
243	wg *sync.WaitGroup,
244	name string,
245	subscriber func(context.Context) <-chan pubsub.Event[T],
246	outputCh chan<- tea.Msg,
247) {
248	wg.Go(func() {
249		subCh := subscriber(ctx)
250		for {
251			select {
252			case event, ok := <-subCh:
253				if !ok {
254					slog.Debug("subscription channel closed", "name", name)
255					return
256				}
257				var msg tea.Msg = event
258				select {
259				case outputCh <- msg:
260				case <-time.After(2 * time.Second):
261					slog.Warn("message dropped due to slow consumer", "name", name)
262				case <-ctx.Done():
263					slog.Debug("subscription cancelled", "name", name)
264					return
265				}
266			case <-ctx.Done():
267				slog.Debug("subscription cancelled", "name", name)
268				return
269			}
270		}
271	})
272}
273
274func (app *App) InitCoderAgent() error {
275	coderAgentCfg := app.config.Agents["coder"]
276	if coderAgentCfg.ID == "" {
277		return fmt.Errorf("coder agent configuration is missing")
278	}
279	var err error
280	app.CoderAgent, err = agent.NewAgent(
281		app.globalCtx,
282		coderAgentCfg,
283		app.Permissions,
284		app.Sessions,
285		app.Messages,
286		app.History,
287		app.LSPClients,
288	)
289	if err != nil {
290		slog.Error("Failed to create coder agent", "err", err)
291		return err
292	}
293
294	// Add MCP client cleanup to shutdown process
295	app.cleanupFuncs = append(app.cleanupFuncs, agent.CloseMCPClients)
296
297	setupSubscriber(app.eventsCtx, app.serviceEventsWG, "coderAgent", app.CoderAgent.Subscribe, app.events)
298	return nil
299}
300
// Subscribe sends events to the TUI as tea.Msgs.
//
// It blocks, passing every message received on app.events to program.Send.
// It returns when the TUI context (derived from the global context) is
// cancelled or app.events is closed. It registers a cleanup step that cancels
// that context and waits for this loop to return. log.RecoverPanic is
// deferred with a callback that logs and calls program.Quit.
//
// NOTE(review): this appends to app.cleanupFuncs without any lock, while
// Shutdown iterates that slice. Because Subscribe blocks, it is presumably
// run on its own goroutine. If so, the append can race with Shutdown. Confirm
// with the race detector.
func (app *App) Subscribe(program *tea.Program) {
	// Deferred first, so it runs last (after tuiWG.Done below).
	defer log.RecoverPanic("app.Subscribe", func() {
		slog.Info("TUI subscription panic: attempting graceful shutdown")
		program.Quit()
	})

	app.tuiWG.Add(1)
	tuiCtx, tuiCancel := context.WithCancel(app.globalCtx)
	// The cleanup cancels the loop below and then waits for it to exit.
	app.cleanupFuncs = append(app.cleanupFuncs, func() error {
		slog.Debug("Cancelling TUI message handler")
		tuiCancel()
		app.tuiWG.Wait()
		return nil
	})
	defer app.tuiWG.Done()

	for {
		select {
		case <-tuiCtx.Done():
			slog.Debug("TUI message handler shutting down")
			return
		case msg, ok := <-app.events:
			if !ok {
				slog.Debug("TUI message channel closed")
				return
			}
			program.Send(msg)
		}
	}
}
332
333// Shutdown performs a graceful shutdown of the application.
334func (app *App) Shutdown() {
335	if app.CoderAgent != nil {
336		app.CoderAgent.CancelAll()
337	}
338
339	for cancel := range app.watcherCancelFuncs.Seq() {
340		cancel()
341	}
342
343	// Wait for all LSP watchers to finish.
344	app.lspWatcherWG.Wait()
345
346	// Get all LSP clients.
347	app.clientsMutex.RLock()
348	clients := make(map[string]*lsp.Client, len(app.LSPClients))
349	maps.Copy(clients, app.LSPClients)
350	app.clientsMutex.RUnlock()
351
352	// Shutdown all LSP clients.
353	for name, client := range clients {
354		shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
355		if err := client.Shutdown(shutdownCtx); err != nil {
356			slog.Error("Failed to shutdown LSP client", "name", name, "error", err)
357		}
358		cancel()
359	}
360
361	// Shutdown the global watcher
362	watcher.Shutdown()
363
364	// Call call cleanup functions.
365	for _, cleanup := range app.cleanupFuncs {
366		if cleanup != nil {
367			if err := cleanup(); err != nil {
368				slog.Error("Failed to cleanup app properly on shutdown", "error", err)
369			}
370		}
371	}
372}