HTTP reverse proxy for Tailscale

oidc and basic auth #1

open opened by sr.aux1.dev targeting main from oidc
Labels

None yet.

Participants 1
AT URI
at://did:plc:tshg7t4nzrrc5kgy6j5q55he/sh.tangled.repo.pull/3lqrsvfqrlw22
+627 -212
Diff #0
+25 -9
go.mod
··· 4 4 5 5 6 6 7 - 7 + require ( 8 8 github.com/google/go-cmp v0.7.0 9 + github.com/lstoll/oidc v1.0.0-beta.4.0.20250106123456-6ffce62670fe 9 10 github.com/oklog/run v1.1.0 10 11 github.com/prometheus/client_golang v1.21.1 12 + github.com/prometheus/common v0.63.0 13 + github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a 11 14 tailscale.com v1.82.0 12 15 ) 13 16 ··· 51 54 52 55 53 56 57 + github.com/klauspost/compress v1.17.11 // indirect 58 + github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a // indirect 59 + github.com/kylelemons/godebug v1.1.0 // indirect 60 + github.com/mdlayher/genetlink v1.3.2 // indirect 61 + github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect 62 + github.com/mdlayher/sdnotify v1.0.0 // indirect 54 63 55 64 56 65 57 66 58 - 59 - 60 - 61 - 62 - 63 - 64 67 github.com/pierrec/lz4/v4 v4.1.21 // indirect 65 68 github.com/prometheus-community/pro-bing v0.4.0 // indirect 66 69 github.com/prometheus/client_model v0.6.1 // indirect 67 - github.com/prometheus/common v0.63.0 // indirect 68 70 github.com/prometheus/procfs v0.15.1 // indirect 69 71 github.com/safchain/ethtool v0.3.0 // indirect 70 72 github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e // indirect 71 73 github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55 // indirect 72 74 github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 // indirect 73 - github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a // indirect 74 75 github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 // indirect 75 76 github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc // indirect 76 77 github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 // indirect 78 + github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19 // indirect 79 + github.com/tink-crypto/tink-go/v2 v2.2.0 // indirect 80 + github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect 81 + github.com/vishvananda/netns v0.0.4 // indirect 82 + github.com/x448/float16 v0.8.4 // indirect 83 + 84 + 85 + 86 + golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac // indirect 87 + golang.org/x/mod v0.23.0 // indirect 88 + golang.org/x/net v0.36.0 // indirect 89 + golang.org/x/oauth2 v0.26.0 // indirect 90 + golang.org/x/sync v0.11.0 // indirect 91 + golang.org/x/sys v0.30.0 // indirect 92 + golang.org/x/term v0.29.0 // indirect
+303 -101
main.go
··· 16 16 17 17 18 18 19 + "strings" 20 + "syscall" 19 21 20 - 21 - 22 - 23 - 24 - 22 + "github.com/lstoll/oidc" 23 + "github.com/lstoll/oidc/middleware" 24 + "github.com/oklog/run" 25 + "github.com/prometheus/client_golang/prometheus" 26 + versioncollector "github.com/prometheus/client_golang/prometheus/collectors/version" 25 27 "github.com/prometheus/client_golang/prometheus/promauto" 26 28 "github.com/prometheus/client_golang/prometheus/promhttp" 27 29 "github.com/prometheus/common/version" 30 + "github.com/tailscale/hujson" 28 31 "tailscale.com/client/local" 29 32 "tailscale.com/client/tailscale/apitype" 33 + "tailscale.com/ipn/ipnstate" 30 34 "tailscale.com/tsnet" 35 + tslogger "tailscale.com/types/logger" 36 + ) 31 37 32 38 33 39 ··· 56 62 57 63 58 64 59 - 60 - 61 65 ) 62 66 ) 63 67 64 - type upstreamFlag []upstream 65 - 66 - func (f *upstreamFlag) String() string { 67 - return fmt.Sprintf("%+v", *f) 68 + type upstream struct { 69 + Name string 70 + Backend string 71 + Prometheus bool 72 + Funnel *funnelConfig `json:"funnel,omitempty"` 68 73 } 69 74 70 - func (f *upstreamFlag) Set(val string) error { 71 - up, err := parseUpstreamFlag(val) 72 - if err != nil { 73 - return err 74 - } 75 - *f = append(*f, up) 76 - return nil 75 + type funnelConfig struct { 76 + Insecure bool 77 + Issuer string 78 + ClientID string 79 + ClientSecret string 77 80 } 78 81 79 - type upstream struct { 80 - name string 81 - backend *url.URL 82 - prometheus bool 83 - funnel bool 84 - } 85 - 86 82 type target struct { 87 83 88 84 89 85 prometheus bool 90 86 } 91 87 92 - func parseUpstreamFlag(fval string) (upstream, error) { 93 - k, v, ok := strings.Cut(fval, "=") 94 - if !ok { 95 - return upstream{}, errors.New("format: name=http://backend") 96 - } 97 - val := strings.Split(v, ";") 98 - be, err := url.Parse(val[0]) 99 - if err != nil { 100 - return upstream{}, err 101 - } 102 - up := upstream{name: k, backend: be} 103 - if len(val) > 1 { 104 - for _, opt := range val[1:] { 105 - switch opt { 106 - case "prometheus": 107 - up.prometheus = true 108 - case "funnel": 109 - up.funnel = true 110 - default: 111 - return upstream{}, fmt.Errorf("unsupported option: %v", opt) 112 - } 113 - } 114 - } 115 - return up, nil 116 - } 117 - 118 88 func main() { 119 89 if err := tsproxy(context.Background()); err != nil { 120 90 fmt.Fprintf(os.Stderr, "tsproxy: %v\n", err) ··· 124 94 125 95 func tsproxy(ctx context.Context) error { 126 96 var ( 127 - state = flag.String("state", "", "Optional directory for storing Tailscale state.") 128 - tslog = flag.Bool("tslog", false, "If true, log Tailscale output.") 129 - port = flag.Int("port", 32019, "HTTP port for metrics and service discovery.") 130 - ver = flag.Bool("version", false, "print the version and exit") 97 + state = flag.String("state", "", "Optional directory for storing Tailscale state.") 98 + tslog = flag.Bool("tslog", false, "If true, log Tailscale output.") 99 + port = flag.Int("port", 32019, "HTTP port for metrics and service discovery.") 100 + ver = flag.Bool("version", false, "print the version and exit") 101 + upfile = flag.String("upstream", "", "path to upstreams config file") 131 102 ) 132 - var upstreams upstreamFlag 133 - flag.Var(&upstreams, "upstream", "Repeated for each upstream. Format: name=http://backend:8000") 134 103 flag.Parse() 135 104 136 105 if *ver { ··· 138 107 os.Exit(0) 139 108 } 140 109 141 - if len(upstreams) == 0 { 110 + if *upfile == "" { 142 111 return fmt.Errorf("required flag missing: upstream") 143 112 } 113 + 114 + in, err := os.ReadFile(*upfile) 115 + if err != nil { 116 + return err 117 + } 118 + inJSON, err := hujson.Standardize(in) 119 + if err != nil { 120 + return fmt.Errorf("hujson: %w", err) 121 + } 122 + var upstreams []upstream 123 + if err := json.Unmarshal(inJSON, &upstreams); err != nil { 124 + return fmt.Errorf("json: %w", err) 125 + } 126 + if len(upstreams) == 0 { 127 + return fmt.Errorf("file does not contain any upstreams: %s", *upfile) 128 + } 129 + 144 130 if *state == "" { 145 131 v, err := os.UserCacheDir() 146 132 if err != nil { ··· 212 198 213 199 214 200 201 + } 215 202 203 + for i, upstream := range upstreams { 204 + log := logger.With(slog.String("upstream", upstream.Name)) 216 205 217 - 218 - 219 - i := i 220 - upstream := upstream 221 - 222 - log := logger.With(slog.String("upstream", upstream.name)) 223 - 224 206 ts := &tsnet.Server{ 225 - Hostname: upstream.name, 226 - Dir: filepath.Join(*state, "tailscale-"+upstream.name), 207 + Hostname: upstream.Name, 208 + Dir: filepath.Join(*state, "tailscale-"+upstream.Name), 227 209 RunWebClient: true, 228 210 } 229 211 defer ts.Close() ··· 242 224 243 225 lc, err := ts.LocalClient() 244 226 if err != nil { 245 - return fmt.Errorf("tailscale: get local client for %s: %w", upstream.name, err) 227 + return fmt.Errorf("tailscale: get local client for %s: %w", upstream.Name, err) 246 228 } 247 229 248 - srv := &http.Server{ 249 - TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate}, 250 - Handler: promhttp.InstrumentHandlerInFlight(requestsInFlight.With(prometheus.Labels{"upstream": upstream.name}), 251 - promhttp.InstrumentHandlerDuration(duration.MustCurryWith(prometheus.Labels{"upstream": upstream.name}), 252 - promhttp.InstrumentHandlerCounter(requests.MustCurryWith(prometheus.Labels{"upstream": upstream.name}), 253 - newReverseProxy(log, lc, upstream.backend)))), 230 + backendURL, err := url.Parse(upstream.Backend) 231 + if err != nil { 232 + return fmt.Errorf("upstream %s: parse backend URL: %w", upstream.Name, err) 254 233 } 234 + // TODO(sr) Instrument proxy.Transport 235 + proxy := &httputil.ReverseProxy{ 236 + Rewrite: func(req *httputil.ProxyRequest) { 237 + req.SetURL(backendURL) 238 + req.SetXForwarded() 239 + req.Out.Host = req.In.Host 240 + }, 241 + } 242 + proxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) { 243 + http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) 244 + logger.Error("upstream error", lerr(err)) 245 + } 255 246 256 - g.Add(func() error { 257 - st, err := ts.Up(ctx) 258 - if err != nil { 259 - return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err) 247 + instrument := func(h http.Handler) http.Handler { 248 + return promhttp.InstrumentHandlerInFlight( 249 + requestsInFlight.With(prometheus.Labels{"upstream": upstream.Name}), 250 + promhttp.InstrumentHandlerDuration( 251 + duration.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}), 252 + promhttp.InstrumentHandlerCounter( 253 + requests.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}), 254 + h, 255 + ), 256 + ), 257 + ) 258 + } 259 + 260 + { 261 + var srv *http.Server 262 + g.Add(func() error { 263 + st, err := ts.Up(ctx) 264 + if err != nil { 265 + return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 266 + } 267 + 268 + srv = &http.Server{Handler: instrument(localTailnetHandler(log, lc, proxy))} 269 + ln, err := ts.Listen("tcp", ":80") 270 + if err != nil { 271 + return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.Name, err) 272 + } 273 + 274 + // register in service discovery when we're ready. 275 + targets[i] = target{name: upstream.Name, prometheus: upstream.Prometheus, magicDNS: st.Self.DNSName} 276 + 277 + return srv.Serve(ln) 278 + }, func(_ error) { 279 + if srv != nil { 280 + if err := srv.Close(); err != nil { 281 + log.Error("server shutdown", lerr(err)) 282 + } 283 + } 284 + cancel() 285 + }) 286 + } 287 + { 288 + var srv *http.Server 289 + g.Add(func() error { 290 + _, err := ts.Up(ctx) 291 + if err != nil { 292 + return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 293 + } 294 + srv = &http.Server{ 295 + TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate}, 296 + Handler: instrument(localTailnetTLSHandler(log, lc, proxy)), 297 + } 298 + 299 + ln, err := ts.Listen("tcp", ":443") 300 + if err != nil { 301 + return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.Name, err) 302 + } 303 + return srv.ServeTLS(ln, "", "") 304 + }, func(_ error) { 305 + if srv != nil { 306 + if err := srv.Close(); err != nil { 307 + log.Error("TLS server shutdown", lerr(err)) 308 + } 309 + } 310 + cancel() 311 + }) 312 + } 313 + if funnel := upstream.Funnel; funnel != nil { 314 + if !funnel.Insecure && funnel.Issuer == "" { 315 + return fmt.Errorf("upstream %s: funnel must set issuer or insecure", upstream.Name) 260 316 } 317 + { 318 + var srv *http.Server 319 + g.Add(func() error { 320 + _, err := ts.Up(ctx) 321 + if err != nil { 322 + return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 323 + } 261 324 262 - // register in service discovery when we're ready. 263 - targets[i] = target{name: upstream.name, prometheus: upstream.prometheus, magicDNS: st.Self.DNSName} 325 + srv = &http.Server{} 326 + if funnel.Issuer != "" { 327 + handler, err := oidcFunnelHandler(ctx, log, lc, funnel, proxy) 328 + if err != nil { 329 + return fmt.Errorf("oidc: %w", err) 330 + } 331 + srv.Handler = handler 332 + } else if funnel.Insecure { 333 + srv.Handler = insecureFunnelHandler(log, lc, proxy) 334 + } else { 335 + panic("funnel misconfigured") 336 + } 337 + srv.Handler = instrument(srv.Handler) 264 338 265 - ln, err := ts.Listen("tcp", ":80") 266 - if err != nil { 267 - return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.name, err) 339 + ln, err := ts.ListenFunnel("tcp", ":443", tsnet.FunnelOnly()) 340 + if err != nil { 341 + return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.Name, err) 342 + } 343 + return srv.Serve(ln) 344 + }, func(_ error) { 345 + if srv != nil { 346 + if err := srv.Close(); err != nil { 347 + log.Error("TLS server shutdown", lerr(err)) 348 + } 349 + } 350 + cancel() 351 + }) 268 352 } 269 - return srv.Serve(ln) 270 - }, func(_ error) { 353 + } 354 + } 271 355 356 + return g.Run() 357 + } 272 358 359 + // localTailnetHandler serves plain-HTTP on the local tailnet. 360 + func localTailnetHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 361 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 362 + whois, err := tsWhoIs(lc, r) 363 + if err != nil { 364 + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 365 + logger.ErrorContext(r.Context(), "tailscale whois", lerr(err)) 366 + return 367 + } 273 368 369 + // Proxy requests from tagged nodes as is. 370 + if whois.Node.IsTagged() { 371 + next.ServeHTTP(w, r) 372 + return 373 + } 274 374 375 + req := r.Clone(r.Context()) 376 + req.Header.Set("X-Webauth-User", whois.UserProfile.LoginName) 377 + req.Header.Set("X-Webauth-Name", whois.UserProfile.DisplayName) 378 + next.ServeHTTP(w, req) 379 + }) 380 + } 275 381 276 - g.Add(func() error { 277 - _, err := ts.Up(ctx) 382 + // localTailnetTLSHandler serves HTTPS on the local tailnet. 383 + func localTailnetTLSHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 384 + var ( 385 + handler = localTailnetHandler(logger, lc, next) 386 + dnsName string 387 + ) 388 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 389 + if r.TLS == nil { 390 + panic("tailnet handler wants tls") 391 + } 392 + 393 + if dnsName == "" { 394 + st, err := lc.StatusWithoutPeers(r.Context()) 278 395 if err != nil { 279 - return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err) 396 + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 397 + logger.ErrorContext(r.Context(), "tailscale status", slog.Any("err", err)) 398 + return 280 399 } 400 + dnsName = strings.TrimSuffix(st.Self.DNSName, ".") 401 + } 281 402 282 - if upstream.funnel { 283 - ln, err := ts.ListenFunnel("tcp", ":443") 284 - if err != nil { 285 - return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.name, err) 286 - } 287 - return srv.Serve(ln) 288 - } 403 + if strings.TrimSuffix(r.Host, ".") != dnsName { 404 + http.Redirect(w, r, fmt.Sprintf("https://%s%s", dnsName, r.RequestURI), http.StatusPermanentRedirect) 405 + return 406 + } 289 407 290 - ln, err := ts.Listen("tcp", ":443") 291 - if err != nil { 292 - return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.name, err) 293 - } 294 - return srv.ServeTLS(ln, "", "") 295 - }, func(_ error) { 408 + handler.ServeHTTP(w, r) 409 + }) 410 + } 411 + 412 + // insecureFunnelHandler handles HTTPS requests coming from Tailscale Funnel nodes. 413 + // This is marked insecure because the upstream is exposed to the public Internet. 414 + // The upstream is responsible for implementing authentication. 415 + func insecureFunnelHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 416 + return localTailnetTLSHandler(logger, lc, next) 417 + } 418 + 419 + // oidcFunnelHandlers serves Funnel requests, requiring authentication via the configured OIDC issuer. 420 + func oidcFunnelHandler(ctx context.Context, logger *slog.Logger, lc tailscaleLocalClient, cfg *funnelConfig, next http.Handler) (http.Handler, error) { 421 + st, err := lc.StatusWithoutPeers(ctx) 422 + if err != nil { 423 + return nil, fmt.Errorf("tailscale status: %w", err) 424 + } 425 + 426 + redir := &url.URL{Scheme: "https", Path: ".oidc-callback"} 427 + redir.Host = strings.TrimSuffix(st.Self.DNSName, ".") 428 + 429 + wrapper, err := middleware.NewFromDiscovery(ctx, nil, cfg.Issuer, cfg.ClientID, cfg.ClientSecret, redir.String()) 430 + if err != nil { 431 + return nil, fmt.Errorf("oidc middleware: %w", err) 432 + } 433 + wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile) 434 + 435 + return wrapper.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 436 + if r.TLS == nil { 437 + panic("oidc handler wants tls") 438 + } 439 + 440 + _, err := tsWhoIs(lc, r) 441 + if err != nil { 442 + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 443 + logger.ErrorContext(r.Context(), "tailscale whois", lerr(err)) 444 + return 445 + } 446 + 447 + tok := middleware.IDJWTFromContext(r.Context()) 448 + if tok == nil { 449 + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 450 + logger.ErrorContext(r.Context(), "jwt token missing") 451 + return 452 + } 453 + email, err := tok.StringClaim("email") 454 + if err != nil { 455 + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 456 + logger.ErrorContext(r.Context(), "claim missing", slog.String("claim", "email")) 457 + return 458 + } 459 + name, err := tok.StringClaim("name") 460 + if err != nil { 461 + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 462 + logger.ErrorContext(r.Context(), "claim missing", slog.String("claim", "name")) 463 + return 464 + } 465 + 466 + req := r.Clone(r.Context()) 467 + req.Header.Set("X-Webauth-User", email) 468 + req.Header.Set("X-Webauth-Name", name) 469 + 470 + next.ServeHTTP(w, r) 471 + })), nil 472 + } 473 + 474 + type tailscaleLocalClient interface { 475 + WhoIs(context.Context, string) (*apitype.WhoIsResponse, error) 476 + StatusWithoutPeers(context.Context) (*ipnstate.Status, error) 477 + } 478 + 479 + func tsWhoIs(lc tailscaleLocalClient, r *http.Request) (*apitype.WhoIsResponse, error) { 480 + whois, err := lc.WhoIs(r.Context(), r.RemoteAddr) 481 + if err != nil { 482 + return nil, fmt.Errorf("tailscale whois: %w", err) 483 + } 484 + 485 + if whois.Node == nil { 486 + return nil, errors.New("tailscale whois: node missing") 487 + } 488 + 489 + if whois.UserProfile == nil { 490 + return nil, errors.New("tailscale whois: user profile missing") 491 + } 492 + return whois, nil 493 + } 494 + 495 + func serveDiscovery(self string, targets []target) http.Handler { 496 + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 497 + var tgs []string
+6
go.sum
··· 117 117 github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 118 118 github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= 119 119 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= 120 + github.com/lstoll/oidc v1.0.0-beta.4.0.20250106123456-6ffce62670fe h1:QBlUtM+Rv9P+W3k9C6+xLgpssfxcKd8Ir+pvNM7E23Y= 121 + github.com/lstoll/oidc v1.0.0-beta.4.0.20250106123456-6ffce62670fe/go.mod h1:H1Y2Ektfl9aWzSHYT1qf6lXpE9mdil6ZavkI/5+N5Qg= 120 122 github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= 121 123 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= 122 124 github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= ··· 183 185 github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= 184 186 github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA= 185 187 github.com/tc-hib/winres v0.2.1/go.mod h1:C/JaNhH3KBvhNKVbvdlDWkbMDO9H4fKKDaN7/07SSuk= 188 + github.com/tink-crypto/tink-go/v2 v2.2.0 h1:L2Da0F2Udh2agtKztdr69mV/KpnY3/lGTkMgLTVIXlA= 189 + github.com/tink-crypto/tink-go/v2 v2.2.0/go.mod h1:JJ6PomeNPF3cJpfWC0lgyTES6zpJILkAX0cJNwlS3xU= 186 190 github.com/u-root/u-root v0.12.0 h1:K0AuBFriwr0w/PGS3HawiAw89e3+MU7ks80GpghAsNs= 187 191 github.com/u-root/u-root v0.12.0/go.mod h1:FYjTOh4IkIZHhjsd17lb8nYW6udgXdJhG1c0r6u0arI= 188 192 github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= ··· 208 212 golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= 209 213 golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA= 210 214 golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I= 215 + golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE= 216 + golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= 211 217 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 212 218 golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= 213 219 golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+222 -102
tsproxy.go
··· 1 1 2 2 3 + import ( 4 + "context" 5 + "crypto/subtle" 6 + "crypto/tls" 7 + "encoding/json" 8 + "errors" 3 9 4 10 5 11 12 + "net" 13 + "net/http" 14 + "net/http/httputil" 15 + "net/netip" 16 + "net/url" 17 + "os" 18 + "path/filepath" 6 19 7 20 8 21 ··· 16 29 17 30 18 31 19 - 20 - 21 - 22 - 23 - 24 - 25 - 26 - 27 - 28 - 29 - 30 32 "github.com/tailscale/hujson" 31 33 "tailscale.com/client/local" 32 34 "tailscale.com/client/tailscale/apitype" 33 - "tailscale.com/ipn/ipnstate" 35 + "tailscale.com/ipn" 34 36 "tailscale.com/tsnet" 35 37 tslogger "tailscale.com/types/logger" 36 38 ) 37 39 40 + // ctxConn is a key to look up a net.Conn stored in an HTTP request's context. 41 + type ctxConn struct{} 38 42 43 + var ( 44 + requestsInFlight = promauto.NewGaugeVec( 45 + prometheus.GaugeOpts{ 39 46 40 47 41 48 ··· 64 71 65 72 66 73 67 - 68 - 69 74 Name string 70 75 Backend string 71 76 Prometheus bool 72 - Funnel *funnelConfig `json:"funnel,omitempty"` 77 + Funnel *funnelConfig 73 78 } 74 79 75 80 type funnelConfig struct { 76 81 82 + Issuer string 83 + ClientID string 84 + ClientSecret string 85 + User string 86 + Password string 87 + IP []string 88 + } 77 89 90 + type target struct { 78 91 79 92 80 93 ··· 233 246 234 247 235 248 249 + } 250 + proxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) { 251 + http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) 252 + log.Error("upstream error", lerr(err)) 253 + } 236 254 255 + instrument := func(h http.Handler) http.Handler { 237 256 238 257 239 258 ··· 251 270 252 271 253 272 254 - 255 - 256 - 257 - 258 - 259 - 260 - 261 - 262 - 263 - 264 - 265 273 return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 266 274 } 267 275 268 - srv = &http.Server{Handler: instrument(localTailnetHandler(log, lc, proxy))} 276 + srv = &http.Server{Handler: instrument(redirect(st.Self.DNSName, false, tailnet(log, lc, proxy)))} 269 277 ln, err := ts.Listen("tcp", ":80") 270 278 if err != nil { 271 279 return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.Name, err) ··· 287 295 { 288 296 var srv *http.Server 289 297 g.Add(func() error { 290 - _, err := ts.Up(ctx) 298 + st, err := ts.Up(ctx) 291 299 if err != nil { 292 300 return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 293 301 } 302 + 294 303 srv = &http.Server{ 295 304 TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate}, 296 - Handler: instrument(localTailnetTLSHandler(log, lc, proxy)), 305 + Handler: instrument(redirect(st.Self.DNSName, true, tailnet(log, lc, proxy))), 297 306 } 298 307 299 308 ln, err := ts.Listen("tcp", ":443") ··· 304 313 }, func(_ error) { 305 314 if srv != nil { 306 315 if err := srv.Close(); err != nil { 307 - log.Error("TLS server shutdown", lerr(err)) 316 + log.Error("server shutdown", lerr(err)) 308 317 } 309 318 } 310 319 cancel() 311 320 }) 312 321 } 313 322 if funnel := upstream.Funnel; funnel != nil { 314 - if !funnel.Insecure && funnel.Issuer == "" { 315 - return fmt.Errorf("upstream %s: funnel must set issuer or insecure", upstream.Name) 316 - } 317 323 { 318 324 var srv *http.Server 319 325 g.Add(func() error { 320 - _, err := ts.Up(ctx) 326 + st, err := ts.Up(ctx) 321 327 if err != nil { 322 328 return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 323 329 } 324 330 325 - srv = &http.Server{} 326 - if funnel.Issuer != "" { 327 - handler, err := oidcFunnelHandler(ctx, log, lc, funnel, proxy) 331 + var handler http.Handler 332 + switch { 333 + case funnel.Insecure: 334 + handler = insecureFunnel(log, lc, proxy) 335 + case funnel.Issuer != "": 336 + redir := &url.URL{Scheme: "https", Host: strings.TrimSuffix(st.Self.DNSName, "."), Path: ".oidc-callback"} 337 + wrapper, err := middleware.NewFromDiscovery(ctx, nil, funnel.Issuer, funnel.ClientID, funnel.ClientSecret, redir.String()) 328 338 if err != nil { 329 - return fmt.Errorf("oidc: %w", err) 339 + return fmt.Errorf("oidc middleware for %s: %w", upstream.Name, err) 330 340 } 331 - srv.Handler = handler 332 - } else if funnel.Insecure { 333 - srv.Handler = insecureFunnelHandler(log, lc, proxy) 334 - } else { 335 - panic("funnel misconfigured") 341 + wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile) 342 + 343 + handler = wrapper.Wrap(oidcFunnel(log, lc, proxy)) 344 + case funnel.User != "": 345 + handler = insecureFunnel(log, lc, basicAuth(log, funnel.User, funnel.Password, proxy)) 346 + default: 347 + return fmt.Errorf("upstream %s must set funnel.insecure or funnel.issuer", upstream.Name) 336 348 } 337 - srv.Handler = instrument(srv.Handler) 338 349 350 + handler = redirect(st.Self.DNSName, true, handler) 351 + 352 + if len(funnel.IP) > 0 { 353 + var allow []netip.Prefix 354 + for _, ip := range funnel.IP { 355 + allow = append(allow, netip.MustParsePrefix(ip)) 356 + } 357 + handler = restrictNetworks(log, allow, handler) 358 + } 359 + 360 + srv = &http.Server{ 361 + Handler: instrument(handler), 362 + ConnContext: func(ctx context.Context, c net.Conn) context.Context { 363 + return context.WithValue(ctx, ctxConn{}, c) 364 + }, 365 + } 366 + 339 367 ln, err := ts.ListenFunnel("tcp", ":443", tsnet.FunnelOnly()) 340 368 if err != nil { 341 369 ··· 344 372 }, func(_ error) { 345 373 if srv != nil { 346 374 if err := srv.Close(); err != nil { 347 - log.Error("TLS server shutdown", lerr(err)) 375 + log.Error("server shutdown", lerr(err)) 348 376 } 349 377 } 350 378 cancel() ··· 356 384 return g.Run() 357 385 } 358 386 359 - // localTailnetHandler serves plain-HTTP on the local tailnet. 360 - func localTailnetHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 387 + func redirect(fqdn string, forceSSL bool, next http.Handler) http.Handler { 388 + if fqdn == "" { 389 + panic("redirect: fqdn cannot be empty") 390 + } 391 + fqdn = strings.TrimSuffix(fqdn, ".") 361 392 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 393 + if forceSSL && r.TLS == nil { 394 + http.Redirect(w, r, fmt.Sprintf("https://%s%s", fqdn, r.RequestURI), http.StatusPermanentRedirect) 395 + return 396 + } 397 + 398 + if r.TLS != nil && strings.TrimSuffix(r.Host, ".") != fqdn { 399 + http.Redirect(w, r, fmt.Sprintf("https://%s%s", fqdn, r.RequestURI), http.StatusPermanentRedirect) 400 + return 401 + } 402 + next.ServeHTTP(w, r) 403 + }) 404 + } 405 + 406 + func basicAuth(logger *slog.Logger, user, password string, next http.Handler) http.Handler { 407 + if user == "" || password == "" { 408 + panic("user and password are required") 409 + } 410 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 411 + u, p, ok := r.BasicAuth() 412 + if ok { 413 + userCheck := subtle.ConstantTimeCompare([]byte(user), []byte(u)) 414 + passwordCheck := subtle.ConstantTimeCompare([]byte(password), []byte(p)) 415 + if userCheck == 1 && passwordCheck == 1 { 416 + next.ServeHTTP(w, r) 417 + return 418 + } 419 + } 420 + logger.ErrorContext(r.Context(), "authentication failed", slog.String("user", u)) 421 + w.Header().Set("WWW-Authenticate", "Basic realm=\"protected\", charset=\"UTF-8\"") 422 + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) 423 + }) 424 + } 425 + 426 + func tailnet(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 427 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 362 428 whois, err := tsWhoIs(lc, r) 363 429 if err != nil { 364 430 ··· 379 445 }) 380 446 } 381 447 382 - // localTailnetTLSHandler serves HTTPS on the local tailnet. 383 - func localTailnetTLSHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 384 - var ( 385 - handler = localTailnetHandler(logger, lc, next) 386 - dnsName string 387 - ) 448 + func insecureFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 388 449 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 389 - if r.TLS == nil { 390 - panic("tailnet handler wants tls") 450 + whois, err := tsWhoIs(lc, r) 451 + if err != nil { 452 + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 453 + logger.ErrorContext(r.Context(), "tailscale whois", lerr(err)) 454 + return 391 455 } 392 - 393 - if dnsName == "" { 394 - st, err := lc.StatusWithoutPeers(r.Context()) 395 - if err != nil { 396 - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 397 - logger.ErrorContext(r.Context(), "tailscale status", slog.Any("err", err)) 398 - return 399 - } 400 - dnsName = strings.TrimSuffix(st.Self.DNSName, ".") 401 - } 402 - 403 - if strings.TrimSuffix(r.Host, ".") != dnsName { 404 - http.Redirect(w, r, fmt.Sprintf("https://%s%s", dnsName, r.RequestURI), http.StatusPermanentRedirect) 456 + if !whois.Node.IsTagged() { 457 + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) 458 + logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node") 405 459 return 406 460 } 407 461 408 - handler.ServeHTTP(w, r) 462 + next.ServeHTTP(w, r) 409 463 }) 410 464 } 411 465 412 - // insecureFunnelHandler handles HTTPS requests coming from Tailscale Funnel nodes. 413 - // This is marked insecure because the upstream is exposed to the public Internet. 414 - // The upstream is responsible for implementing authentication. 415 - func insecureFunnelHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 416 - return localTailnetTLSHandler(logger, lc, next) 417 - } 418 - 419 - // oidcFunnelHandlers serves Funnel requests, requiring authentication via the configured OIDC issuer. 420 - func oidcFunnelHandler(ctx context.Context, logger *slog.Logger, lc tailscaleLocalClient, cfg *funnelConfig, next http.Handler) (http.Handler, error) { 421 - st, err := lc.StatusWithoutPeers(ctx) 422 - if err != nil { 423 - return nil, fmt.Errorf("tailscale status: %w", err) 424 - } 425 - 426 - redir := &url.URL{Scheme: "https", Path: ".oidc-callback"} 427 - redir.Host = strings.TrimSuffix(st.Self.DNSName, ".") 428 - 429 - wrapper, err := middleware.NewFromDiscovery(ctx, nil, cfg.Issuer, cfg.ClientID, cfg.ClientSecret, redir.String()) 430 - if err != nil { 431 - return nil, fmt.Errorf("oidc middleware: %w", err) 432 - } 433 - wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile) 434 - 435 - return wrapper.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 436 - if r.TLS == nil { 437 - panic("oidc handler wants tls") 438 - } 439 - 440 - _, err := tsWhoIs(lc, r) 466 + func oidcFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 467 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 468 + whois, err := tsWhoIs(lc, r) 441 469 if err != nil { 442 470 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 443 471 logger.ErrorContext(r.Context(), "tailscale whois", lerr(err)) 444 472 return 445 473 } 474 + if !whois.Node.IsTagged() { 475 + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) 476 + logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node") 477 + return 478 + } 446 479 447 480 tok := middleware.IDJWTFromContext(r.Context()) 448 481 if tok == nil { 449 - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 482 + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) 450 483 logger.ErrorContext(r.Context(), "jwt token missing") 451 484 return 452 485 } ··· 467 500 req.Header.Set("X-Webauth-User", email) 468 501 req.Header.Set("X-Webauth-Name", name) 469 502 470 - next.ServeHTTP(w, r) 471 - })), nil 503 + next.ServeHTTP(w, req) 504 + }) 472 505 } 473 506 507 + // restrictNetworks will only allow clients from the provided IP networks to 508 + // access the given handler. If skip prefixes are set, paths that match any 509 + // of the regular expressions will not have restrictions applied. 510 + func restrictNetworks(logger *slog.Logger, allowedNetworks []netip.Prefix, next http.Handler) http.Handler { 511 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 512 + // If the funneled connection is from tsnet, then the net.Conn will be of 513 + // type ipn.FunnelConn. 514 + netConn := r.Context().Value(ctxConn{}) 515 + // if the conn is wrapped inside TLS, unwrap it 516 + if tlsConn, ok := netConn.(*tls.Conn); ok { 517 + netConn = tlsConn.NetConn() 518 + } 519 + var remote netip.AddrPort 520 + if fconn, ok := netConn.(*ipn.FunnelConn); ok { 521 + remote = fconn.Src 522 + } else if v, err := netip.ParseAddrPort(r.RemoteAddr); err == nil { 523 + remote = v 524 + } else { 525 + logger.Error("restrictNetworks: cannot parse client IP:port", lerr(err), slog.String("remote", r.RemoteAddr)) 526 + w.WriteHeader(http.StatusUnauthorized) 527 + return 528 + } 529 + 530 + for _, wl := range allowedNetworks { 531 + if wl.Contains(remote.Addr()) { 532 + next.ServeHTTP(w, r) 533 + return 534 + } 535 + } 536 + 537 + w.WriteHeader(http.StatusForbidden) 538 + _, _ = fmt.Fprint(w, badNetwork) 539 + }) 540 + } 541 + 542 + const badNetwork = ` 543 + <html> 544 + <head><title>Untrusted network</title></head> 545 + <body><h1>Access from untrusted networks not permitted</h1></body> 546 + </html> 547 + ` 548 + 474 549 type tailscaleLocalClient interface { 475 550 WhoIs(context.Context, string) (*apitype.WhoIsResponse, error) 476 - StatusWithoutPeers(context.Context) (*ipnstate.Status, error) 477 551 } 478 552 479 553 func tsWhoIs(lc tailscaleLocalClient, r *http.Request) (*apitype.WhoIsResponse, error) { 554 + 555 + 556 + 557 + 558 + 559 + 560 + 561 + 562 + 563 + 564 + 565 + 566 + 567 + 568 + 569 + 570 + 571 + 572 + 573 + 574 + 575 + 576 + 577 + 578 + 579 + 580 + 581 + 582 + 583 + 584 + 585 + 586 + 587 + 588 + 589 + 590 + 591 + 592 + 593 + 594 + 595 + } 596 + 597 + func lerr(err error) slog.Attr { 598 + return slog.Any("err", err) 599 + }
+71
tsproxy_test.go
··· 310 310 } 311 311 } 312 312 313 + func TestBasicAuthHandler(t *testing.T) { 314 + t.Parallel() 315 + 316 + logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) 317 + 318 + for _, tc := range []struct { 319 + name string 320 + user string 321 + password string 322 + request func(*http.Request) 323 + wantNext bool 324 + wantStatus int 325 + }{ 326 + { 327 + name: "no basic auth provided", 328 + user: "admin", 329 + password: "secret", 330 + request: func(_ *http.Request) {}, 331 + wantStatus: http.StatusUnauthorized, 332 + }, 333 + { 334 + name: "wrong user", 335 + user: "admin", 336 + password: "secret", 337 + request: func(r *http.Request) { r.SetBasicAuth("bad", "secret") }, 338 + wantStatus: http.StatusUnauthorized, 339 + }, 340 + { 341 + name: "wrong password", 342 + user: "admin", 343 + password: "secret", 344 + request: func(r *http.Request) { r.SetBasicAuth("admin", "bad") }, 345 + wantStatus: http.StatusUnauthorized, 346 + }, 347 + { 348 + name: "ok", 349 + user: "admin", 350 + password: "secret", 351 + request: func(r *http.Request) { r.SetBasicAuth("admin", "secret") }, 352 + wantNext: true, 353 + wantStatus: http.StatusOK, 354 + }, 355 + } { 356 + t.Run(tc.name, func(t *testing.T) { 357 + t.Parallel() 358 + 359 + var nextReq *http.Request 360 + h := basicAuth(logger, tc.user, tc.password, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 361 + nextReq = r 362 + fmt.Fprintf(w, "OK") 363 + })) 364 + w := httptest.NewRecorder() 365 + req := httptest.NewRequest("", "/", nil) 366 + tc.request(req) 367 + h.ServeHTTP(w, req) 368 + resp := w.Result() 369 + 370 + if want, got := tc.wantStatus, resp.StatusCode; want != got { 371 + t.Errorf("want status %d, got: %d", want, got) 372 + } 373 + 374 + if tc.wantNext && nextReq == nil { 375 + t.Fatalf("next handler not called") 376 + } 377 + if !tc.wantNext && nextReq != nil { 378 + t.Fatalf("next handler should not have been called") 379 + } 380 + }) 381 + } 382 + } 383 + 313 384 func TestServeDiscovery(t *testing.T) { 314 385 t.Parallel() 315 386

History

1 round 0 comments
sign up or login to add to the discussion
sr.aux1.dev submitted #0
10 commits
expand
read upstreams from a config file
use go 1.22 loopvar
split server and handler
redirect https to fqdn
implement oidc funnel handler
revamp http handlers implementation and tests
fix panic in lerr when err is nil
add basic auth funnel handler
fix "upstream error" log line to include the upstream
add IP allowlist support for funnel
no conflicts, ready to merge
expand 0 comments