HTTP reverse proxy for Tailscale

split server and handler

This is in preparation for adding OIDC.

+138 -88
+134 -77
main.go
··· 4 4 "context" 5 5 "crypto/tls" 6 6 "encoding/json" 7 + "errors" 7 8 "flag" 8 9 "fmt" 9 10 "log/slog" ··· 219 220 if err != nil { 220 221 return fmt.Errorf("upstream %s: parse backend URL: %w", upstream.Name, err) 221 222 } 223 + // TODO(sr) Instrument proxy.Transport 224 + proxy := &httputil.ReverseProxy{ 225 + Rewrite: func(req *httputil.ProxyRequest) { 226 + req.SetURL(backendURL) 227 + req.SetXForwarded() 228 + req.Out.Host = req.In.Host 229 + }, 230 + } 231 + proxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) { 232 + http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) 233 + logger.Error("upstream error", lerr(err)) 234 + } 222 235 223 - srv := &http.Server{ 224 - TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate}, 225 - Handler: promhttp.InstrumentHandlerInFlight(requestsInFlight.With(prometheus.Labels{"upstream": upstream.Name}), 226 - promhttp.InstrumentHandlerDuration(duration.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}), 227 - promhttp.InstrumentHandlerCounter(requests.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}), 228 - newReverseProxy(log, lc, backendURL)))), 236 + instrument := func(h http.Handler) http.Handler { 237 + return promhttp.InstrumentHandlerInFlight( 238 + requestsInFlight.With(prometheus.Labels{"upstream": upstream.Name}), 239 + promhttp.InstrumentHandlerDuration( 240 + duration.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}), 241 + promhttp.InstrumentHandlerCounter( 242 + requests.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}), 243 + h, 244 + ), 245 + ), 246 + ) 229 247 } 230 248 231 - g.Add(func() error { 232 - st, err := ts.Up(ctx) 233 - if err != nil { 234 - return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.Name, err) 235 - } 249 + { 250 + var srv *http.Server 251 + g.Add(func() error { 252 + st, err := ts.Up(ctx) 253 + if err != nil { 254 + return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 255 + } 236 256 237 - // register in service discovery when we're ready. 238 - targets[i] = target{name: upstream.Name, prometheus: upstream.Prometheus, magicDNS: st.Self.DNSName} 257 + srv = &http.Server{Handler: instrument(localTailnetHandler(log, lc, proxy))} 258 + ln, err := ts.Listen("tcp", ":80") 259 + if err != nil { 260 + return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.Name, err) 261 + } 239 262 240 - ln, err := ts.Listen("tcp", ":80") 241 - if err != nil { 242 - return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.Name, err) 243 - } 244 - return srv.Serve(ln) 245 - }, func(_ error) { 246 - if err := srv.Close(); err != nil { 247 - log.Error("server shutdown", lerr(err)) 248 - } 249 - cancel() 250 - }) 251 - g.Add(func() error { 252 - _, err := ts.Up(ctx) 253 - if err != nil { 254 - return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.Name, err) 255 - } 263 + // register in service discovery when we're ready. 264 + targets[i] = target{name: upstream.Name, prometheus: upstream.Prometheus, magicDNS: st.Self.DNSName} 256 265 257 - if upstream.Funnel { 258 - ln, err := ts.ListenFunnel("tcp", ":443") 266 + return srv.Serve(ln) 267 + }, func(_ error) { 268 + if srv != nil { 269 + if err := srv.Close(); err != nil { 270 + log.Error("server shutdown", lerr(err)) 271 + } 272 + } 273 + cancel() 274 + }) 275 + } 276 + { 277 + var srv *http.Server 278 + g.Add(func() error { 279 + _, err := ts.Up(ctx) 259 280 if err != nil { 260 - return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.Name, err) 281 + return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 282 + } 283 + srv = &http.Server{ 284 + TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate}, 285 + Handler: instrument(localTailnetTLSHandler(log, lc, proxy)), 261 286 } 262 - return srv.Serve(ln) 263 - } 264 287 265 - ln, err := ts.Listen("tcp", ":443") 266 - if err != nil { 267 - return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.Name, err) 268 - } 269 - return srv.ServeTLS(ln, "", "") 270 - }, func(_ error) { 271 - if err := srv.Close(); err != nil { 272 - log.Error("TLS server shutdown", lerr(err)) 288 + ln, err := ts.Listen("tcp", ":443") 289 + if err != nil { 290 + return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.Name, err) 291 + } 292 + return srv.ServeTLS(ln, "", "") 293 + }, func(_ error) { 294 + if srv != nil { 295 + if err := srv.Close(); err != nil { 296 + log.Error("TLS server shutdown", lerr(err)) 297 + } 298 + } 299 + cancel() 300 + }) 301 + } 302 + if upstream.Funnel { 303 + { 304 + var srv *http.Server 305 + g.Add(func() error { 306 + _, err := ts.Up(ctx) 307 + if err != nil { 308 + return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 309 + } 310 + srv = &http.Server{ 311 + Handler: instrument(insecureFunnelHandler(log, lc, proxy)), 312 + } 313 + 314 + ln, err := ts.ListenFunnel("tcp", ":443", tsnet.FunnelOnly()) 315 + if err != nil { 316 + return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.Name, err) 317 + } 318 + return srv.Serve(ln) 319 + }, func(_ error) { 320 + if srv != nil { 321 + if err := srv.Close(); err != nil { 322 + log.Error("TLS server shutdown", lerr(err)) 323 + } 324 + } 325 + cancel() 326 + }) 273 327 } 274 - cancel() 275 - }) 328 + } 276 329 } 277 330 278 331 return g.Run() 279 332 } 280 333 281 - type tailscaleLocalClient interface { 282 - WhoIs(context.Context, string) (*apitype.WhoIsResponse, error) 283 - } 284 - 285 - func newReverseProxy(logger *slog.Logger, lc tailscaleLocalClient, url *url.URL) http.HandlerFunc { 286 - // TODO(sr) Instrument proxy.Transport 287 - rproxy := &httputil.ReverseProxy{ 288 - Rewrite: func(req *httputil.ProxyRequest) { 289 - req.SetURL(url) 290 - req.SetXForwarded() 291 - req.Out.Host = req.In.Host 292 - }, 293 - } 294 - rproxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) { 295 - http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) 296 - logger.Error("upstream error", lerr(err)) 297 - } 298 - 334 + // localTailnetHandler serves plain-HTTP on the local tailnet. 335 + func localTailnetHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 299 336 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 300 - whois, err := lc.WhoIs(r.Context(), r.RemoteAddr) 337 + whois, err := tsWhoIs(lc, r) 301 338 if err != nil { 302 339 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 303 340 logger.Error("tailscale whois", lerr(err)) 304 341 return 305 342 } 306 343 307 - if whois.Node == nil { 308 - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 309 - logger.Error("tailscale whois", slog.String("err", "node missing")) 310 - return 311 - } 312 - 313 - if whois.UserProfile == nil { 314 - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 315 - logger.Error("tailscale whois", slog.String("err", "user profile missing")) 316 - return 317 - } 318 - 319 344 // Proxy requests from tagged nodes as is. 320 345 if whois.Node.IsTagged() { 321 - rproxy.ServeHTTP(w, r) 346 + next.ServeHTTP(w, r) 322 347 return 323 348 } 324 349 325 350 req := r.Clone(r.Context()) 326 351 req.Header.Set("X-Webauth-User", whois.UserProfile.LoginName) 327 352 req.Header.Set("X-Webauth-Name", whois.UserProfile.DisplayName) 328 - rproxy.ServeHTTP(w, req) 353 + next.ServeHTTP(w, req) 329 354 }) 355 + } 356 + 357 + // localTailnetTLSHandler serves HTTPS on the local tailnet. 358 + func localTailnetTLSHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 359 + return localTailnetHandler(logger, lc, next) 360 + } 361 + 362 + // insecureFunnelHandler handles HTTPS requests coming from Tailscale Funnel nodes. 363 + // This is marked insecure because the upstream is exposed to the public Internet. 364 + // The upstream is responsible for implementing authentication. 365 + func insecureFunnelHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 366 + return localTailnetHandler(logger, lc, next) 367 + } 368 + 369 + type tailscaleLocalClient interface { 370 + WhoIs(context.Context, string) (*apitype.WhoIsResponse, error) 371 + } 372 + 373 + func tsWhoIs(lc tailscaleLocalClient, r *http.Request) (*apitype.WhoIsResponse, error) { 374 + whois, err := lc.WhoIs(r.Context(), r.RemoteAddr) 375 + if err != nil { 376 + return nil, fmt.Errorf("tailscale whois: %w", err) 377 + } 378 + 379 + if whois.Node == nil { 380 + return nil, errors.New("tailscale whois: node missing") 381 + } 382 + 383 + if whois.UserProfile == nil { 384 + return nil, errors.New("tailscale whois: user profile missing") 385 + } 386 + return whois, nil 330 387 } 331 388 332 389 func serveDiscovery(self string, targets []target) http.Handler {
+4 -11
tsproxy_test.go
··· 5 5 "errors" 6 6 "fmt" 7 7 "io" 8 - "log" 9 8 "log/slog" 10 9 "net/http" 11 10 "net/http/httptest" 12 - "net/url" 13 11 "testing" 14 12 15 13 "github.com/google/go-cmp/cmp" ··· 27 25 return c.whois(ctx, remoteAddr) 28 26 } 29 27 30 - func TestReverseProxy(t *testing.T) { 28 + func TestLocalTailnetHandler(t *testing.T) { 31 29 t.Parallel() 32 30 33 31 for _, tc := range []struct { ··· 80 78 t.Run(tc.name, func(t *testing.T) { 81 79 t.Parallel() 82 80 lc := &fakeLocalClient{whois: tc.whois} 83 - be := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 81 + be := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 84 82 for k, v := range r.Header { 85 83 w.Header().Set(k, v[0]) 86 84 } 87 85 fmt.Fprintln(w, "Hi from the backend.") 88 - })) 89 - defer be.Close() 90 - beURL, err := url.Parse(be.URL) 91 - if err != nil { 92 - log.Fatal(err) 93 - } 94 - px := httptest.NewServer(newReverseProxy(slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), lc, beURL)) 86 + }) 87 + px := httptest.NewServer(localTailnetHandler(slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), lc, be)) 95 88 defer px.Close() 96 89 97 90 resp, err := http.Get(px.URL)