server.go

  1package server
  2
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/rpc"
	"os"
	"os/user"
	"path/filepath"
	"runtime"
	"strings"
	"sync/atomic"
	"time"

	"github.com/charmbracelet/crush/internal/app"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"

	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2"
)
 23
 24// ErrServerClosed is returned when the server is closed.
 25var ErrServerClosed = fmt.Errorf("server closed")
 26
 27// InstanceState represents the state of a running [app.App] instance.
 28type InstanceState uint8
 29
 30const (
 31	// InstanceStateCreated indicates that the instance has been created but not yet started.
 32	InstanceStateCreated InstanceState = iota
 33	// InstanceStateStarted indicates that the instance is currently running.
 34	InstanceStateStarted
 35	// InstanceStateStopped indicates that the instance has been stopped.
 36	InstanceStateStopped
 37)
 38
 39// Instance represents a running [app.App] instance with its associated
 40// resources and state.
 41type Instance struct {
 42	*app.App
 43	State InstanceState
 44	id    string
 45	path  string
 46}
 47
 48// ID returns the unique identifier of the instance.
 49func (i *Instance) ID() string {
 50	return i.id
 51}
 52
 53// Path returns the filesystem path associated with the instance.
 54func (i *Instance) Path() string {
 55	return i.path
 56}
 57
 58// DefaultAddr returns the default address path for the Crush server based on
 59// the operating system.
 60func DefaultAddr() string {
 61	sock := "crush.sock"
 62	user, err := user.Current()
 63	if err == nil && user.Uid != "" {
 64		sock = fmt.Sprintf("crush-%s.sock", user.Uid)
 65	}
 66	if runtime.GOOS == "windows" {
 67		return fmt.Sprintf(`\\.\pipe\%s`, sock)
 68	}
 69	return filepath.Join(os.TempDir(), sock)
 70}
 71
 72// Server represents a Crush server instance bound to a specific address.
 73type Server struct {
 74	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
 75	Addr string
 76
 77	// instances is a map of running applications managed by the server.
 78	instances *csync.Map[string, *Instance]
 79	// listeners is the network listener for the server.
 80	listeners *csync.Map[*net.Listener, struct{}]
 81	cfg       *config.Config
 82	logger    *slog.Logger
 83
 84	shutdown atomic.Bool
 85}
 86
 87// SetLogger sets the logger for the server.
 88func (s *Server) SetLogger(logger *slog.Logger) {
 89	s.logger = logger
 90}
 91
 92// DefaultServer returns a new [Server] instance with the default address.
 93func DefaultServer(cfg *config.Config) *Server {
 94	return NewServer(cfg, "unix", DefaultAddr())
 95}
 96
 97// NewServer is a helper to create a new [Server] instance with the given
 98// address. On Windows, if the address is not a "tcp" address, it will be
 99// converted to a named pipe format.
100func NewServer(cfg *config.Config, network, address string) *Server {
101	if runtime.GOOS == "windows" && !strings.HasPrefix(address, "tcp") &&
102		!strings.HasPrefix(address, `\\.\pipe\`) {
103		// On Windows, convert to named pipe format if not TCP
104		// (e.g., "mypipe" -> "\\.\pipe\mypipe")
105		address = fmt.Sprintf(`\\.\pipe\%s`, address)
106	}
107
108	s := new(Server)
109	s.Addr = address
110	s.cfg = cfg
111	s.instances = csync.NewMap[string, *Instance]()
112	rpc.Register(&ServerProto{s})
113	return s
114}
115
116// Serve accepts incoming connections on the listener.
117func (s *Server) Serve(ln net.Listener) error {
118	if s.listeners == nil {
119		s.listeners = csync.NewMap[*net.Listener, struct{}]()
120	}
121	s.listeners.Set(&ln, struct{}{})
122
123	var tempDelay time.Duration // how long to sleep on accept failure
124	for {
125		conn, err := ln.Accept()
126		if err != nil {
127			if s.shuttingDown() {
128				return ErrServerClosed
129			}
130			if ne, ok := err.(net.Error); ok && ne.Temporary() {
131				if tempDelay == 0 {
132					tempDelay = 5 * time.Millisecond
133				} else {
134					tempDelay *= 2
135				}
136				if max := 1 * time.Second; tempDelay > max {
137					tempDelay = max
138				}
139				time.Sleep(tempDelay)
140				continue
141			}
142			return fmt.Errorf("failed to accept connection: %w", err)
143		}
144		go s.handleConn(conn)
145	}
146}
147
148// ListenAndServe starts the server and begins accepting connections.
149func (s *Server) ListenAndServe() error {
150	ln, err := listen("unix", s.Addr)
151	if err != nil {
152		return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
153	}
154	return s.Serve(ln)
155}
156
157// Close force close all listeners and connections.
158func (s *Server) Close() error {
159	s.shutdown.Store(true)
160	var firstErr error
161	for k := range s.listeners.Seq2() {
162		if err := (*k).Close(); err != nil && firstErr == nil {
163			firstErr = err
164		}
165		s.listeners.Del(k)
166	}
167	return firstErr
168}
169
170// Shutdown gracefully shuts down the server without interrupting active
171// connections. It stops accepting new connections and waits for existing
172// connections to finish.
173func (s *Server) Shutdown(ctx context.Context) error {
174	// TODO: implement graceful shutdown
175	return s.Close()
176}
177
178func (s *Server) handleConn(conn net.Conn) {
179	s.info("accepted connection", "remote_addr", conn.LocalAddr())
180	codec := &ServerCodec{
181		MsgpackCodec: msgpackrpc.NewCodec(true, true, conn),
182		logger: s.logger.With(
183			slog.String("remote_addr", conn.RemoteAddr().String()),
184			slog.String("local_addr", conn.LocalAddr().String()),
185		),
186	}
187	rpc.ServeCodec(codec)
188}
189
190func (s *Server) shuttingDown() bool {
191	return s.shutdown.Load()
192}
193
194func (s *Server) info(msg string, args ...any) {
195	if s.logger != nil {
196		s.logger.Info(msg, args...)
197	}
198}
199
200func (s *Server) debug(msg string, args ...any) {
201	if s.logger != nil {
202		s.logger.Debug(msg, args...)
203	}
204}
205
206func (s *Server) error(msg string, args ...any) {
207	if s.logger != nil {
208		s.logger.Error(msg, args...)
209	}
210}
211
212func (s *Server) warn(msg string, args ...any) {
213	if s.logger != nil {
214		s.logger.Warn(msg, args...)
215	}
216}