diff --git a/.ruby-version b/.ruby-version index 15a2799..4f5e697 100644 --- a/.ruby-version +++ b/.ruby-version @@ -1 +1 @@ -3.3.0 +3.4.5 diff --git a/CHANGELOG.md b/CHANGELOG.md index b846823..163d88a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +## v0.1.15 / 2025-08-05 + +* Ensure responses are flushable (preventing issues with SSE) (#87) +* Add host to cache key (#86) +* Add X-Request-Start header (#85) +* Add `LOG_REQUESTS` option to control request logging (#50) + +## v0.1.14 / 2025-06-18 + +* Build with Go 1.24.4 (#81) + ## v0.1.13 / 2025-04-21 * Update deps to address CVEs (#74) diff --git a/README.md b/README.md index 5f4d36e..5da4b77 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,7 @@ environment variables that you can set. | `EAB_KID` | The EAB key identifier to use when provisioning TLS certificates, if required. | None | | `EAB_HMAC_KEY` | The Base64-encoded EAB HMAC key to use when provisioning TLS certificates, if required. | None | | `FORWARD_HEADERS` | Whether to forward X-Forwarded-* headers from the client. | Disabled when running with TLS; enabled otherwise | +| `LOG_REQUESTS` | Log all requests. Set to `0` or `false` to disable request logging | Enabled | | `DEBUG` | Set to `1` or `true` to enable debug logging. | Disabled | To prevent naming clashes with your application's own environment variables, diff --git a/go.mod b/go.mod index cfb595e..0370c17 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/basecamp/thruster -go 1.24.2 +go 1.24.4 require ( github.com/klauspost/compress v1.17.4 diff --git a/internal/cache_handler_test.go b/internal/cache_handler_test.go index 59c7f42..99302a6 100644 --- a/internal/cache_handler_test.go +++ b/internal/cache_handler_test.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -165,6 +166,52 @@ func TestCacheHandler_vary_header(t *testing.T) { assert.Equal(t, "hit", resp.Header().Get("X-Cache")) } +func TestCacheHandler_different_hosts(t *testing.T) { + cache := newTestCache() + handler := NewCacheHandler(cache, 1024, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host := r.Header.Get("Host") + w.Header().Set("Cache-Control", "public, max-age=600") + w.Write([]byte(host)) + })) + + doReq := func(url string) *httptest.ResponseRecorder { + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", url, nil) + host := strings.Split(url, "://")[1] + r.Header.Set("Host", host) + handler.ServeHTTP(w, r) + return w + } + + resp := doReq("https://example.com") + assert.Equal(t, "example.com", resp.Body.String()) + assert.Equal(t, "miss", resp.Header().Get("X-Cache")) + + resp = doReq("https://example.com") + assert.Equal(t, "example.com", resp.Body.String()) + assert.Equal(t, "hit", resp.Header().Get("X-Cache")) + + resp = doReq("https://another.com") + assert.Equal(t, "another.com", resp.Body.String()) + assert.Equal(t, "miss", resp.Header().Get("X-Cache")) + + resp = doReq("https://another.com") + assert.Equal(t, "another.com", resp.Body.String()) + assert.Equal(t, "hit", resp.Header().Get("X-Cache")) + + resp = doReq("https://example.com/test") + assert.Equal(t, "example.com/test", resp.Body.String()) + assert.Equal(t, "miss", resp.Header().Get("X-Cache")) + + resp = doReq("https://another.com/test") + assert.Equal(t, "another.com/test", resp.Body.String()) + assert.Equal(t, "miss", resp.Header().Get("X-Cache")) + + resp = doReq("https://another.com/test") + assert.Equal(t, "another.com/test", resp.Body.String()) + assert.Equal(t, "hit", resp.Header().Get("X-Cache")) +} + func TestCacheHandler_range_requests_are_not_cached(t *testing.T) { cache := newTestCache() diff --git a/internal/cacheable_response.go b/internal/cacheable_response.go index b6906aa..6f9f6f4 100644 --- a/internal/cacheable_response.go +++ b/internal/cacheable_response.go @@ -75,6 +75,13 @@ func (c *CacheableResponse) WriteHeader(statusCode int) { c.headersWritten = true } +func (c *CacheableResponse) Flush() { + flusher, ok := c.responseWriter.(http.Flusher) + if ok { + flusher.Flush() + } +} + func (c *CacheableResponse) CacheStatus() (bool, time.Time) { if c.stasher.Overflowed() { return false, time.Time{} diff --git a/internal/config.go b/internal/config.go index 655f9eb..f715e8e 100644 --- a/internal/config.go +++ b/internal/config.go @@ -33,7 +33,8 @@ const ( defaultHttpReadTimeout = 30 * time.Second defaultHttpWriteTimeout = 30 * time.Second - defaultLogLevel = slog.LevelInfo + defaultLogLevel = slog.LevelInfo + defaultLogRequests = true ) type Config struct { @@ -62,7 +63,8 @@ type Config struct { ForwardHeaders bool - LogLevel slog.Level + LogLevel slog.Level + LogRequests bool } func NewConfig() (*Config, error) { @@ -99,7 +101,8 @@ func NewConfig() (*Config, error) { HttpReadTimeout: getEnvDuration("HTTP_READ_TIMEOUT", defaultHttpReadTimeout), HttpWriteTimeout: getEnvDuration("HTTP_WRITE_TIMEOUT", defaultHttpWriteTimeout), - LogLevel: logLevel, + LogLevel: logLevel, + LogRequests: getEnvBool("LOG_REQUESTS", defaultLogRequests), } config.ForwardHeaders = getEnvBool("FORWARD_HEADERS", !config.HasTLS()) diff --git a/internal/config_test.go b/internal/config_test.go index 7446693..b8ad054 100644 --- a/internal/config_test.go +++ b/internal/config_test.go @@ -116,6 +116,7 @@ func TestConfig_override_defaults_with_env_vars(t *testing.T) { usingEnvVar(t, "GZIP_COMPRESSION_ENABLED", "0") usingEnvVar(t, "DEBUG", "1") usingEnvVar(t, "ACME_DIRECTORY", "https://acme-staging-v02.api.letsencrypt.org/directory") + usingEnvVar(t, "LOG_REQUESTS", "false") c, err := NewConfig() require.NoError(t, err) @@ -127,6 +128,7 @@ func TestConfig_override_defaults_with_env_vars(t *testing.T) { assert.Equal(t, false, c.GzipCompressionEnabled) assert.Equal(t, slog.LevelDebug, c.LogLevel) assert.Equal(t, "https://acme-staging-v02.api.letsencrypt.org/directory", c.ACMEDirectoryURL) + assert.Equal(t, false, c.LogRequests) } func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) { @@ -136,6 +138,7 @@ func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) { usingEnvVar(t, "THRUSTER_HTTP_READ_TIMEOUT", "5") usingEnvVar(t, "THRUSTER_X_SENDFILE_ENABLED", "0") usingEnvVar(t, "THRUSTER_DEBUG", "1") + usingEnvVar(t, "THRUSTER_LOG_REQUESTS", "0") c, err := NewConfig() require.NoError(t, err) @@ -145,6 +148,7 @@ func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) { assert.Equal(t, 5*time.Second, c.HttpReadTimeout) assert.Equal(t, false, c.XSendfileEnabled) assert.Equal(t, slog.LevelDebug, c.LogLevel) + assert.Equal(t, false, c.LogRequests) } func TestConfig_prefixed_variables_take_precedence_over_non_prefixed(t *testing.T) { diff --git a/internal/handler.go b/internal/handler.go index c9fbafa..16609cf 100644 --- a/internal/handler.go +++ b/internal/handler.go @@ -17,12 +17,15 @@ type HandlerOptions struct { xSendfileEnabled bool gzipCompressionEnabled bool forwardHeaders bool + logRequests bool } func NewHandler(options HandlerOptions) http.Handler { handler := NewProxyHandler(options.targetUrl, options.badGatewayPage, options.forwardHeaders) handler = NewCacheHandler(options.cache, options.maxCacheableResponseBody, handler) handler = NewSendfileHandler(options.xSendfileEnabled, handler) + handler = NewRequestStartMiddleware(handler) + if options.gzipCompressionEnabled { handler = gzhttp.GzipHandler(handler) } @@ -31,7 +34,9 @@ func NewHandler(options HandlerOptions) http.Handler { handler = http.MaxBytesHandler(handler, int64(options.maxRequestBody)) } - handler = NewLoggingMiddleware(slog.Default(), handler) + if options.logRequests { + handler = NewLoggingMiddleware(slog.Default(), handler) + } return handler } diff --git a/internal/handler_test.go b/internal/handler_test.go index 77d2584..a564a86 100644 --- a/internal/handler_test.go +++ b/internal/handler_test.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "fmt" "net/http" "net/http/httptest" "net/url" @@ -263,6 +264,43 @@ func TestHandlerXForwardedHeadersDropsExistingHeadersWhenForwardingNotEnabled(t h.ServeHTTP(w, r) } +func TestHandlerAddsXRequestStartHeader(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + header := r.Header.Get("X-Request-Start") + assert.NotEmpty(t, header, "X-Request-Start header should be present") + assert.Regexp(t, `^t=\d+$`, header, "X-Request-Start header should be in format t=msec") + })) + defer upstream.Close() + + h := NewHandler(handlerOptions(upstream.URL)) + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + h.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestHandlerAllowsFlushingTheResponseBody(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "data: event\n\n") + w.(http.Flusher).Flush() + + })) + defer upstream.Close() + + h := NewHandler(handlerOptions(upstream.URL)) + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + h.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type")) + assert.True(t, w.Flushed) +} + // Helpers func handlerOptions(targetUrl string) HandlerOptions { @@ -276,5 +314,6 @@ func handlerOptions(targetUrl string) HandlerOptions { maxCacheableResponseBody: 1024, badGatewayPage: "", forwardHeaders: true, + logRequests: true, } } diff --git a/internal/logging_middleware.go b/internal/logging_middleware.go index 289ecdd..fa9c581 100644 --- a/internal/logging_middleware.go +++ b/internal/logging_middleware.go @@ -87,3 +87,10 @@ func (r *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { } return con, rw, err } + +func (r *responseWriter) Flush() { + flusher, ok := r.ResponseWriter.(http.Flusher) + if ok { + flusher.Flush() + } +} diff --git a/internal/request_start_middleware.go b/internal/request_start_middleware.go new file mode 100644 index 0000000..40f3a69 --- /dev/null +++ b/internal/request_start_middleware.go @@ -0,0 +1,17 @@ +package internal + +import ( + "fmt" + "net/http" + "time" +) + +func NewRequestStartMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Request-Start") == "" { + timestamp := time.Now().UnixMilli() + r.Header.Set("X-Request-Start", fmt.Sprintf("t=%d", timestamp)) + } + next.ServeHTTP(w, r) + }) +} \ No newline at end of file diff --git a/internal/request_start_middleware_test.go b/internal/request_start_middleware_test.go new file mode 100644 index 0000000..ca2a0b1 --- /dev/null +++ b/internal/request_start_middleware_test.go @@ -0,0 +1,52 @@ +package internal + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRequestStartMiddleware(t *testing.T) { + var capturedHeader string + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeader = r.Header.Get("X-Request-Start") + }) + + middleware := NewRequestStartMiddleware(nextHandler) + + before := time.Now().UnixMilli() + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + middleware.ServeHTTP(w, req) + after := time.Now().UnixMilli() + + assert.NotEmpty(t, capturedHeader) + assert.Regexp(t, `^t=\d+$`, capturedHeader) + + timestampStr := capturedHeader[2:] + timestamp, err := strconv.ParseInt(timestampStr, 10, 64) + assert.NoError(t, err) + assert.GreaterOrEqual(t, timestamp, before) + assert.LessOrEqual(t, timestamp, after) +} + +func TestRequestStartMiddlewareDoesNotOverwriteExistingHeader(t *testing.T) { + existingHeader := "t=1234567890" + var capturedHeader string + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeader = r.Header.Get("X-Request-Start") + }) + + middleware := NewRequestStartMiddleware(nextHandler) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Request-Start", existingHeader) + w := httptest.NewRecorder() + middleware.ServeHTTP(w, req) + + assert.Equal(t, existingHeader, capturedHeader) +} \ No newline at end of file diff --git a/internal/sendfile_handler.go b/internal/sendfile_handler.go index d6a7cf8..9c20547 100644 --- a/internal/sendfile_handler.go +++ b/internal/sendfile_handler.go @@ -79,6 +79,13 @@ func (w *sendfileWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return hijacker.Hijack() } +func (w *sendfileWriter) Flush() { + flusher, ok := w.w.(http.Flusher) + if ok { + flusher.Flush() + } +} + func (w *sendfileWriter) sendingFilename() string { return w.w.Header().Get("X-Sendfile") } diff --git a/internal/service.go b/internal/service.go index 04fee25..e844783 100644 --- a/internal/service.go +++ b/internal/service.go @@ -27,6 +27,7 @@ func (s *Service) Run() int { maxRequestBody: s.config.MaxRequestBody, badGatewayPage: s.config.BadGatewayPage, forwardHeaders: s.config.ForwardHeaders, + logRequests: s.config.LogRequests, } handler := NewHandler(handlerOptions) diff --git a/internal/variant.go b/internal/variant.go index 3685075..f34e9e1 100644 --- a/internal/variant.go +++ b/internal/variant.go @@ -25,6 +25,7 @@ func (v *Variant) CacheKey() CacheKey { hash.Write([]byte(v.r.Method)) hash.Write([]byte(v.r.URL.Path)) hash.Write([]byte(v.r.URL.Query().Encode())) + hash.Write([]byte(v.r.Host)) for _, name := range v.headerNames { hash.Write([]byte(name + "=" + v.r.Header.Get(name))) diff --git a/lib/thruster/version.rb b/lib/thruster/version.rb index 0964a9d..4c71429 100644 --- a/lib/thruster/version.rb +++ b/lib/thruster/version.rb @@ -1,3 +1,3 @@ module Thruster - VERSION = "0.1.13" + VERSION = "0.1.15" end