root.go

  1package cmd
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7	"sync"
  8	"time"
  9
 10	tea "github.com/charmbracelet/bubbletea"
 11	"github.com/kujtimiihoxha/termai/internal/app"
 12	"github.com/kujtimiihoxha/termai/internal/assets"
 13	"github.com/kujtimiihoxha/termai/internal/config"
 14	"github.com/kujtimiihoxha/termai/internal/db"
 15	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 16	"github.com/kujtimiihoxha/termai/internal/logging"
 17	"github.com/kujtimiihoxha/termai/internal/pubsub"
 18	"github.com/kujtimiihoxha/termai/internal/tui"
 19	zone "github.com/lrstanley/bubblezone"
 20	"github.com/spf13/cobra"
 21)
 22
 23var rootCmd = &cobra.Command{
 24	Use:   "termai",
 25	Short: "A terminal ai assistant",
 26	Long:  `A terminal ai assistant`,
 27	RunE: func(cmd *cobra.Command, args []string) error {
 28		// If the help flag is set, show the help message
 29		if cmd.Flag("help").Changed {
 30			cmd.Help()
 31			return nil
 32		}
 33
 34		// Load the config
 35		debug, _ := cmd.Flags().GetBool("debug")
 36		cwd, _ := cmd.Flags().GetString("cwd")
 37		if cwd != "" {
 38			err := os.Chdir(cwd)
 39			if err != nil {
 40				return fmt.Errorf("failed to change directory: %v", err)
 41			}
 42		}
 43		if cwd == "" {
 44			c, err := os.Getwd()
 45			if err != nil {
 46				return fmt.Errorf("failed to get current working directory: %v", err)
 47			}
 48			cwd = c
 49		}
 50		_, err := config.Load(cwd, debug)
 51		if err != nil {
 52			return err
 53		}
 54
 55		err = assets.WriteAssets()
 56		if err != nil {
 57			logging.Error("Error writing assets: %v", err)
 58		}
 59
 60		// Connect DB, this will also run migrations
 61		conn, err := db.Connect()
 62		if err != nil {
 63			return err
 64		}
 65
 66		// Create main context for the application
 67		ctx, cancel := context.WithCancel(context.Background())
 68		defer cancel()
 69
 70		app := app.New(ctx, conn)
 71
 72		// Set up the TUI
 73		zone.NewGlobal()
 74		program := tea.NewProgram(
 75			tui.New(app),
 76			tea.WithAltScreen(),
 77			tea.WithMouseCellMotion(),
 78		)
 79
 80		// Initialize MCP tools in the background
 81		initMCPTools(ctx, app)
 82
 83		// Setup the subscriptions, this will send services events to the TUI
 84		ch, cancelSubs := setupSubscriptions(app)
 85
 86		// Create a context for the TUI message handler
 87		tuiCtx, tuiCancel := context.WithCancel(ctx)
 88		var tuiWg sync.WaitGroup
 89		tuiWg.Add(1)
 90
 91		// Set up message handling for the TUI
 92		go func() {
 93			defer tuiWg.Done()
 94			defer func() {
 95				if r := recover(); r != nil {
 96					logging.Error("Panic in TUI message handling: %v", r)
 97					attemptTUIRecovery(program)
 98				}
 99			}()
100
101			for {
102				select {
103				case <-tuiCtx.Done():
104					logging.Info("TUI message handler shutting down")
105					return
106				case msg, ok := <-ch:
107					if !ok {
108						logging.Info("TUI message channel closed")
109						return
110					}
111					program.Send(msg)
112				}
113			}
114		}()
115
116		// Cleanup function for when the program exits
117		cleanup := func() {
118			// Shutdown the app
119			app.Shutdown()
120
121			// Cancel subscriptions first
122			cancelSubs()
123
124			// Then cancel TUI message handler
125			tuiCancel()
126
127			// Wait for TUI message handler to finish
128			tuiWg.Wait()
129
130			logging.Info("All goroutines cleaned up")
131		}
132
133		// Run the TUI
134		result, err := program.Run()
135		cleanup()
136
137		if err != nil {
138			logging.Error("TUI error: %v", err)
139			return fmt.Errorf("TUI error: %v", err)
140		}
141
142		logging.Info("TUI exited with result: %v", result)
143		return nil
144	},
145}
146
147// attemptTUIRecovery tries to recover the TUI after a panic
148func attemptTUIRecovery(program *tea.Program) {
149	logging.Info("Attempting to recover TUI after panic")
150
151	// We could try to restart the TUI or gracefully exit
152	// For now, we'll just quit the program to avoid further issues
153	program.Quit()
154}
155
156func initMCPTools(ctx context.Context, app *app.App) {
157	go func() {
158		defer func() {
159			if r := recover(); r != nil {
160				logging.Error("Panic in MCP goroutine: %v", r)
161			}
162		}()
163
164		// Create a context with timeout for the initial MCP tools fetch
165		ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
166		defer cancel()
167
168		// Set this up once with proper error handling
169		agent.GetMcpTools(ctxWithTimeout, app.Permissions)
170		logging.Info("MCP message handling goroutine exiting")
171	}()
172}
173
174func setupSubscriber[T any](
175	ctx context.Context,
176	wg *sync.WaitGroup,
177	name string,
178	subscriber func(context.Context) <-chan pubsub.Event[T],
179	outputCh chan<- tea.Msg,
180) {
181	wg.Add(1)
182	go func() {
183		defer wg.Done()
184		defer func() {
185			if r := recover(); r != nil {
186				logging.Error("Panic in %s subscription goroutine: %v", name, r)
187			}
188		}()
189
190		for {
191			select {
192			case event, ok := <-subscriber(ctx):
193				if !ok {
194					logging.Info("%s subscription channel closed", name)
195					return
196				}
197
198				// Convert generic event to tea.Msg if needed
199				var msg tea.Msg = event
200
201				// Non-blocking send with timeout to prevent deadlocks
202				select {
203				case outputCh <- msg:
204				case <-time.After(500 * time.Millisecond):
205					logging.Warn("%s message dropped due to slow consumer", name)
206				case <-ctx.Done():
207					logging.Info("%s subscription cancelled", name)
208					return
209				}
210			case <-ctx.Done():
211				logging.Info("%s subscription cancelled", name)
212				return
213			}
214		}
215	}()
216}
217
218func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
219	ch := make(chan tea.Msg, 100)
220	// Add a buffer to prevent blocking
221	wg := sync.WaitGroup{}
222	ctx, cancel := context.WithCancel(context.Background())
223	// Setup each subscription using the helper
224	setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
225	setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
226	setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
227	setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
228
229	// Return channel and a cleanup function
230	cleanupFunc := func() {
231		logging.Info("Cancelling all subscriptions")
232		cancel() // Signal all goroutines to stop
233
234		// Wait with a timeout for all goroutines to complete
235		waitCh := make(chan struct{})
236		go func() {
237			wg.Wait()
238			close(waitCh)
239		}()
240
241		select {
242		case <-waitCh:
243			logging.Info("All subscription goroutines completed successfully")
244		case <-time.After(5 * time.Second):
245			logging.Warn("Timed out waiting for some subscription goroutines to complete")
246		}
247
248		close(ch) // Safe to close after all writers are done or timed out
249	}
250	return ch, cleanupFunc
251}
252
253func Execute() {
254	err := rootCmd.Execute()
255	if err != nil {
256		os.Exit(1)
257	}
258}
259
260func init() {
261	rootCmd.Flags().BoolP("help", "h", false, "Help")
262	rootCmd.Flags().BoolP("debug", "d", false, "Debug")
263	rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
264}