client: refresh session cookies

Amolith created

Use the standard HTTP cookie jar for password sessions so cookie expiry
is handled by normal cookie rules. Retry once after an auth failure by
expiring the stale session cookie and logging in again from the
configured credentials.

Add tests covering cookie expiry and server-side rejection of stale
sessions.

Change summary

internal/silverbullet/client.go      | 156 +++++++++++++++++++++--------
internal/silverbullet/client_test.go | 127 ++++++++++++++++++++++++
2 files changed, 240 insertions(+), 43 deletions(-)

Detailed changes

internal/silverbullet/client.go 🔗

@@ -12,9 +12,11 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"net/http/cookiejar"
 	"net/url"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 )
 
@@ -33,18 +35,18 @@ type Client struct {
 	baseURL    string
 	httpClient *http.Client
 	auth       Auth
-	// cached session from password login
-	cachedCookieName string
-	cachedCookieVal  string
+	authMu     sync.Mutex
 }
 
 // New creates a new SilverBullet client.
 // The baseURL is normalized to remove trailing slashes.
 func New(baseURL string, auth Auth) *Client {
+	jar, _ := cookiejar.New(nil)
 	return &Client{
 		baseURL: strings.TrimRight(baseURL, "/"),
 		httpClient: &http.Client{
 			Timeout: 6 * time.Hour, // long enough for lua_script with X-Timeout up to 21600s
+			Jar:     jar,
 			CheckRedirect: func(req *http.Request, via []*http.Request) error {
 				return http.ErrUseLastResponse
 			},
@@ -92,11 +94,8 @@ func (c *Client) ExecuteLua(ctx context.Context, script string, timeout int) (*L
 
 	req.Header.Set("Content-Type", "text/plain")
 	req.Header.Set("X-Timeout", strconv.Itoa(timeout))
-	if err := c.setAuth(req); err != nil {
-		return nil, fmt.Errorf("setting auth: %w", err)
-	}
 
-	resp, err := c.httpClient.Do(req)
+	resp, err := c.do(req)
 	if err != nil {
 		return nil, fmt.Errorf("executing lua: %w", err)
 	}
@@ -135,11 +134,7 @@ func (c *Client) Screenshot(ctx context.Context) ([]byte, error) {
 		return nil, fmt.Errorf("creating request: %w", err)
 	}
 
-	if err := c.setAuth(req); err != nil {
-		return nil, fmt.Errorf("setting auth: %w", err)
-	}
-
-	resp, err := c.httpClient.Do(req)
+	resp, err := c.do(req)
 	if err != nil {
 		return nil, fmt.Errorf("fetching screenshot: %w", err)
 	}
@@ -179,10 +174,6 @@ func (c *Client) ConsoleLogs(ctx context.Context, limit int, since int64) (*Logs
 		return nil, fmt.Errorf("creating request: %w", err)
 	}
 
-	if err := c.setAuth(req); err != nil {
-		return nil, fmt.Errorf("setting auth: %w", err)
-	}
-
 	q := req.URL.Query()
 	q.Set("limit", strconv.Itoa(limit))
 	if since > 0 {
@@ -190,7 +181,7 @@ func (c *Client) ConsoleLogs(ctx context.Context, limit int, since int64) (*Logs
 	}
 	req.URL.RawQuery = q.Encode()
 
-	resp, err := c.httpClient.Do(req)
+	resp, err := c.do(req)
 	if err != nil {
 		return nil, fmt.Errorf("fetching logs: %w", err)
 	}
@@ -218,9 +209,62 @@ func (c *Client) resolveURL(path string) (string, error) {
 	return url.JoinPath(c.baseURL, path)
 }
 
+// do sends an authenticated request.
+// If a password session is rejected, it discards the cached cookie, logs in again,
+// and retries the request once.
+func (c *Client) do(req *http.Request) (*http.Response, error) {
+	if err := c.setAuth(req); err != nil {
+		return nil, fmt.Errorf("setting auth: %w", err)
+	}
+
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	if !c.usesPasswordAuth() || !isAuthFailure(resp.StatusCode) {
+		return resp, nil
+	}
+
+	_, _ = io.Copy(io.Discard, resp.Body)
+	_ = resp.Body.Close()
+
+	if req.Body != nil && req.GetBody == nil {
+		return nil, fmt.Errorf("authentication failed and request body cannot be replayed")
+	}
+	if err := c.expireSessionCookies(req.URL); err != nil {
+		return nil, fmt.Errorf("expiring session cookie: %w", err)
+	}
+
+	retryReq := req.Clone(req.Context())
+	retryReq.Header = req.Header.Clone()
+	retryReq.Header.Del("Cookie")
+	if req.GetBody != nil {
+		retryBody, err := req.GetBody()
+		if err != nil {
+			return nil, fmt.Errorf("recreating request body: %w", err)
+		}
+		retryReq.Body = retryBody
+	}
+
+	if err := c.setAuth(retryReq); err != nil {
+		return nil, fmt.Errorf("refreshing auth: %w", err)
+	}
+
+	return c.httpClient.Do(retryReq)
+}
+
+func (c *Client) usesPasswordAuth() bool {
+	return c.auth.User != "" && c.auth.Pass != ""
+}
+
+func isAuthFailure(status int) bool {
+	return status == http.StatusUnauthorized || (status >= 300 && status < 400)
+}
+
 // setAuth sets authentication headers on the request.
 // Bearer token goes on the Authorization header.
-// Password auth logs in via POST /.auth and sends the session cookie.
+// Password auth logs in via POST /.auth; the client's cookie jar sends the
+// resulting session cookie on subsequent requests until it expires.
 // Both can coexist — they use different headers.
 func (c *Client) setAuth(req *http.Request) error {
 	// Bearer token on Authorization header
@@ -229,7 +273,7 @@ func (c *Client) setAuth(req *http.Request) error {
 	}
 
 	// Password auth via session cookie
-	if c.auth.User != "" && c.auth.Pass != "" {
+	if c.usesPasswordAuth() {
 		if err := c.ensureSessionCookie(req); err != nil {
 			return fmt.Errorf("session login: %w", err)
 		}
@@ -238,27 +282,29 @@ func (c *Client) setAuth(req *http.Request) error {
 	return nil
 }
 
-// ensureSessionCookie logs in via POST /.auth if needed and sets the session cookie.
+// ensureSessionCookie logs in via POST /.auth if the cookie jar has no current
+// auth cookie. The jar applies normal cookie expiration rules.
 func (c *Client) ensureSessionCookie(req *http.Request) error {
-	// Use cached session if available
-	if c.cachedCookieName != "" && c.cachedCookieVal != "" {
-		req.AddCookie(&http.Cookie{
-			Name:  c.cachedCookieName,
-			Value: c.cachedCookieVal,
-		})
+	loginURL, err := c.resolveURL("/.auth")
+	if err != nil {
+		return fmt.Errorf("resolving login URL: %w", err)
+	}
+	loginCookieURL, err := url.Parse(loginURL)
+	if err != nil {
+		return fmt.Errorf("parsing login URL: %w", err)
+	}
+
+	c.authMu.Lock()
+	defer c.authMu.Unlock()
+
+	if c.hasAuthCookie(loginCookieURL) {
 		return nil
 	}
 
-	// Log in to get session cookie
 	form := url.Values{}
 	form.Set("username", c.auth.User)
 	form.Set("password", c.auth.Pass)
 
-	loginURL, err := c.resolveURL("/.auth")
-	if err != nil {
-		return fmt.Errorf("resolving login URL: %w", err)
-	}
-
 	loginReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, loginURL, strings.NewReader(form.Encode()))
 	if err != nil {
 		return fmt.Errorf("creating login request: %w", err)
@@ -274,19 +320,43 @@ func (c *Client) ensureSessionCookie(req *http.Request) error {
 	if loginResp.StatusCode != http.StatusOK {
 		return fmt.Errorf("login failed with status %d", loginResp.StatusCode)
 	}
+	if !c.hasAuthCookie(loginCookieURL) {
+		return fmt.Errorf("login succeeded but no auth cookie returned")
+	}
 
-	// Extract session cookie from Set-Cookie header
-	for _, cookie := range loginResp.Cookies() {
-		if strings.HasPrefix(cookie.Name, "auth_") {
-			c.cachedCookieName = cookie.Name
-			c.cachedCookieVal = cookie.Value
-			req.AddCookie(&http.Cookie{
-				Name:  cookie.Name,
-				Value: cookie.Value,
-			})
-			return nil
+	return nil
+}
+
+func (c *Client) hasAuthCookie(u *url.URL) bool {
+	if c.httpClient.Jar == nil {
+		return false
+	}
+	for _, cookie := range c.httpClient.Jar.Cookies(u) {
+		if strings.HasPrefix(cookie.Name, "auth_") && cookie.Value != "" {
+			return true
 		}
 	}
+	return false
+}
+
+func (c *Client) expireSessionCookies(u *url.URL) error {
+	if c.httpClient.Jar == nil {
+		return nil
+	}
+
+	c.authMu.Lock()
+	defer c.authMu.Unlock()
 
-	return fmt.Errorf("login succeeded but no auth cookie returned")
+	cookies := c.httpClient.Jar.Cookies(u)
+	for _, cookie := range cookies {
+		if strings.HasPrefix(cookie.Name, "auth_") {
+			c.httpClient.Jar.SetCookies(u, []*http.Cookie{{
+				Name:   cookie.Name,
+				Value:  "",
+				Path:   "/",
+				MaxAge: -1,
+			}})
+		}
+	}
+	return nil
 }

internal/silverbullet/client_test.go 🔗

@@ -10,7 +10,9 @@ import (
 	"fmt"
 	"net/http"
 	"net/http/httptest"
+	"sync"
 	"testing"
+	"time"
 )
 
 func newTestServer(mux *http.ServeMux) *httptest.Server {
@@ -312,6 +314,131 @@ func TestBothAuth(t *testing.T) {
 	}
 }
 
+func TestPasswordAuthReloginsAfterCookieExpires(t *testing.T) {
+	mux := http.NewServeMux()
+	var mu sync.Mutex
+	loginCount := 0
+	currentToken := ""
+
+	mux.HandleFunc("/.auth", func(w http.ResponseWriter, r *http.Request) {
+		if r.Method != http.MethodPost {
+			w.WriteHeader(http.StatusMethodNotAllowed)
+			return
+		}
+		if r.FormValue("username") != "testuser" || r.FormValue("password") != "testpass" {
+			w.WriteHeader(http.StatusUnauthorized)
+			return
+		}
+
+		mu.Lock()
+		loginCount++
+		currentToken = fmt.Sprintf("mock-jwt-token-%d", loginCount)
+		token := currentToken
+		mu.Unlock()
+
+		http.SetCookie(w, &http.Cookie{
+			Name:   "auth_session",
+			Value:  token,
+			MaxAge: 1,
+		})
+		w.WriteHeader(http.StatusOK)
+	})
+	mux.HandleFunc("/.runtime/logs", func(w http.ResponseWriter, r *http.Request) {
+		cookie, err := r.Cookie("auth_session")
+		if err != nil {
+			t.Error("expected auth_session cookie, got none")
+			w.WriteHeader(http.StatusUnauthorized)
+			return
+		}
+
+		mu.Lock()
+		expectedToken := currentToken
+		mu.Unlock()
+		if cookie.Value != expectedToken {
+			t.Errorf("expected cookie value %q, got %q", expectedToken, cookie.Value)
+		}
+		_ = json.NewEncoder(w).Encode(LogsResult{Logs: []LogEntry{}})
+	})
+
+	srv := newTestServer(mux)
+	defer srv.Close()
+
+	client := testClientWithBasicAuth(srv.URL)
+	if _, err := client.ConsoleLogs(context.Background(), 1, 0); err != nil {
+		t.Fatalf("first ConsoleLogs failed: %v", err)
+	}
+	time.Sleep(1100 * time.Millisecond)
+	if _, err := client.ConsoleLogs(context.Background(), 1, 0); err != nil {
+		t.Fatalf("second ConsoleLogs failed: %v", err)
+	}
+
+	mu.Lock()
+	defer mu.Unlock()
+	if loginCount != 2 {
+		t.Fatalf("expected 2 logins after cookie expiry, got %d", loginCount)
+	}
+}
+
+func TestPasswordAuthRefreshesRejectedSession(t *testing.T) {
+	mux := http.NewServeMux()
+	var mu sync.Mutex
+	loginCount := 0
+
+	mux.HandleFunc("/.auth", func(w http.ResponseWriter, r *http.Request) {
+		if r.Method != http.MethodPost {
+			w.WriteHeader(http.StatusMethodNotAllowed)
+			return
+		}
+		if r.FormValue("username") != "testuser" || r.FormValue("password") != "testpass" {
+			w.WriteHeader(http.StatusUnauthorized)
+			return
+		}
+
+		mu.Lock()
+		loginCount++
+		token := fmt.Sprintf("mock-jwt-token-%d", loginCount)
+		mu.Unlock()
+
+		http.SetCookie(w, &http.Cookie{
+			Name:  "auth_session",
+			Value: token,
+		})
+		w.WriteHeader(http.StatusOK)
+	})
+	mux.HandleFunc("/.runtime/lua_script", func(w http.ResponseWriter, r *http.Request) {
+		cookie, err := r.Cookie("auth_session")
+		if err != nil {
+			t.Error("expected auth_session cookie, got none")
+			w.WriteHeader(http.StatusUnauthorized)
+			return
+		}
+		if cookie.Value == "mock-jwt-token-1" {
+			w.WriteHeader(http.StatusUnauthorized)
+			return
+		}
+		if cookie.Value != "mock-jwt-token-2" {
+			t.Errorf("expected refreshed cookie value, got %q", cookie.Value)
+		}
+
+		result := LuaResult{Result: json.RawMessage(`"ok"`)}
+		_ = json.NewEncoder(w).Encode(result)
+	})
+
+	srv := newTestServer(mux)
+	defer srv.Close()
+
+	client := testClientWithBasicAuth(srv.URL)
+	if _, err := client.ExecuteLua(context.Background(), "return 'ok'", 30); err != nil {
+		t.Fatalf("ExecuteLua failed: %v", err)
+	}
+
+	mu.Lock()
+	defer mu.Unlock()
+	if loginCount != 2 {
+		t.Fatalf("expected 2 logins after rejected session, got %d", loginCount)
+	}
+}
+
 func TestTrailingSlashURL(t *testing.T) {
 	// Verify that trailing slashes in the base URL are handled correctly
 	client := New("http://localhost:3000/", Auth{})