HTTP reverse proxy for Tailscale

revamp http handlers implementation and tests

+289 -164
+1 -1
go.mod
··· 6 6 7 7 require ( 8 8 github.com/google/go-cmp v0.7.0 9 + github.com/lstoll/oidc v1.0.0-beta.4.0.20250106123456-6ffce62670fe 9 10 github.com/oklog/run v1.1.0 10 11 github.com/prometheus/client_golang v1.21.1 11 12 github.com/prometheus/common v0.63.0 ··· 56 57 github.com/klauspost/compress v1.17.11 // indirect 57 58 github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a // indirect 58 59 github.com/kylelemons/godebug v1.1.0 // indirect 59 - github.com/lstoll/oidc v1.0.0-beta.4.0.20250106123456-6ffce62670fe // indirect 60 60 github.com/mdlayher/genetlink v1.3.2 // indirect 61 61 github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect 62 62 github.com/mdlayher/sdnotify v1.0.0 // indirect
+63 -78
main.go tsproxy.go
··· 30 30 "github.com/tailscale/hujson" 31 31 "tailscale.com/client/local" 32 32 "tailscale.com/client/tailscale/apitype" 33 - "tailscale.com/ipn/ipnstate" 34 33 "tailscale.com/tsnet" 35 34 tslogger "tailscale.com/types/logger" 36 35 ) ··· 69 68 Name string 70 69 Backend string 71 70 Prometheus bool 72 - Funnel *funnelConfig `json:"funnel,omitempty"` 71 + Funnel *funnelConfig 73 72 } 74 73 75 74 type funnelConfig struct { ··· 265 264 return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 266 265 } 267 266 268 - srv = &http.Server{Handler: instrument(localTailnetHandler(log, lc, proxy))} 267 + srv = &http.Server{Handler: instrument(redirect(st.Self.DNSName, false, tailnet(log, lc, proxy)))} 269 268 ln, err := ts.Listen("tcp", ":80") 270 269 if err != nil { 271 270 return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.Name, err) ··· 287 286 { 288 287 var srv *http.Server 289 288 g.Add(func() error { 290 - _, err := ts.Up(ctx) 289 + st, err := ts.Up(ctx) 291 290 if err != nil { 292 291 return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 293 292 } 293 + 294 294 srv = &http.Server{ 295 295 TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate}, 296 - Handler: instrument(localTailnetTLSHandler(log, lc, proxy)), 296 + Handler: instrument(redirect(st.Self.DNSName, true, tailnet(log, lc, proxy))), 297 297 } 298 298 299 299 ln, err := ts.Listen("tcp", ":443") ··· 304 304 }, func(_ error) { 305 305 if srv != nil { 306 306 if err := srv.Close(); err != nil { 307 - log.Error("TLS server shutdown", lerr(err)) 307 + log.Error("server shutdown", lerr(err)) 308 308 } 309 309 } 310 310 cancel() 311 311 }) 312 312 } 313 313 if funnel := upstream.Funnel; funnel != nil { 314 - if !funnel.Insecure && funnel.Issuer == "" { 315 - return fmt.Errorf("upstream %s: funnel must set issuer or insecure", upstream.Name) 316 - } 317 314 { 318 315 var srv *http.Server 319 316 g.Add(func() error { 320 - _, err := ts.Up(ctx) 317 + st, err := ts.Up(ctx) 321 318 if err != nil { 322 319 return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err) 323 320 } 324 321 325 - srv = &http.Server{} 326 - if funnel.Issuer != "" { 327 - handler, err := oidcFunnelHandler(ctx, log, lc, funnel, proxy) 322 + var handler http.Handler 323 + switch { 324 + case funnel.Insecure: 325 + handler = insecureFunnel(log, lc, proxy) 326 + case funnel.Issuer != "": 327 + redir := &url.URL{Scheme: "https", Host: strings.TrimSuffix(st.Self.DNSName, "."), Path: ".oidc-callback"} 328 + wrapper, err := middleware.NewFromDiscovery(ctx, nil, funnel.Issuer, funnel.ClientID, funnel.ClientSecret, redir.String()) 328 329 if err != nil { 329 - return fmt.Errorf("oidc: %w", err) 330 + return fmt.Errorf("oidc middleware for %s: %w", upstream.Name, err) 330 331 } 331 - srv.Handler = handler 332 - } else if funnel.Insecure { 333 - srv.Handler = insecureFunnelHandler(log, lc, proxy) 334 - } else { 335 - panic("funnel misconfigured") 332 + wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile) 333 + 334 + handler = wrapper.Wrap(oidcFunnel(log, lc, proxy)) 335 + default: 336 + return fmt.Errorf("upstream %s must set funnel.insecure or funnel.issuer", upstream.Name) 336 337 } 337 - srv.Handler = instrument(srv.Handler) 338 + srv = &http.Server{Handler: instrument(redirect(st.Self.DNSName, true, handler))} 338 339 339 340 ln, err := ts.ListenFunnel("tcp", ":443", tsnet.FunnelOnly()) 340 341 if err != nil { ··· 344 345 }, func(_ error) { 345 346 if srv != nil { 346 347 if err := srv.Close(); err != nil { 347 - log.Error("TLS server shutdown", lerr(err)) 348 + log.Error("server shutdown", lerr(err)) 348 349 } 349 350 } 350 351 cancel() ··· 356 357 return g.Run() 357 358 } 358 359 359 - // localTailnetHandler serves plain-HTTP on the local tailnet. 360 - func localTailnetHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 360 + func redirect(fqdn string, forceSSL bool, next http.Handler) http.Handler { 361 + if fqdn == "" { 362 + panic("redirect: fqdn cannot be empty") 363 + } 364 + fqdn = strings.TrimSuffix(fqdn, ".") 365 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 366 + if forceSSL && r.TLS == nil { 367 + http.Redirect(w, r, fmt.Sprintf("https://%s%s", fqdn, r.RequestURI), http.StatusPermanentRedirect) 368 + return 369 + } 370 + 371 + if r.TLS != nil && strings.TrimSuffix(r.Host, ".") != fqdn { 372 + http.Redirect(w, r, fmt.Sprintf("https://%s%s", fqdn, r.RequestURI), http.StatusPermanentRedirect) 373 + return 374 + } 375 + next.ServeHTTP(w, r) 376 + }) 377 + } 378 + 379 + func tailnet(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 361 380 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 362 381 whois, err := tsWhoIs(lc, r) 363 382 if err != nil { ··· 379 398 }) 380 399 } 381 400 382 - // localTailnetTLSHandler serves HTTPS on the local tailnet. 383 - func localTailnetTLSHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 384 - var ( 385 - handler = localTailnetHandler(logger, lc, next) 386 - dnsName string 387 - ) 401 + func insecureFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 388 402 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 389 - if r.TLS == nil { 390 - panic("tailnet handler wants tls") 403 + whois, err := tsWhoIs(lc, r) 404 + if err != nil { 405 + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 406 + logger.ErrorContext(r.Context(), "tailscale whois", lerr(err)) 407 + return 391 408 } 392 - 393 - if dnsName == "" { 394 - st, err := lc.StatusWithoutPeers(r.Context()) 395 - if err != nil { 396 - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 397 - logger.ErrorContext(r.Context(), "tailscale status", slog.Any("err", err)) 398 - return 399 - } 400 - dnsName = strings.TrimSuffix(st.Self.DNSName, ".") 401 - } 402 - 403 - if strings.TrimSuffix(r.Host, ".") != dnsName { 404 - http.Redirect(w, r, fmt.Sprintf("https://%s%s", dnsName, r.RequestURI), http.StatusPermanentRedirect) 409 + if !whois.Node.IsTagged() { 410 + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) 411 + logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node") 405 412 return 406 413 } 407 414 408 - handler.ServeHTTP(w, r) 415 + next.ServeHTTP(w, r) 409 416 }) 410 417 } 411 418 412 - // insecureFunnelHandler handles HTTPS requests coming from Tailscale Funnel nodes. 413 - // This is marked insecure because the upstream is exposed to the public Internet. 414 - // The upstream is responsible for implementing authentication. 415 - func insecureFunnelHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 416 - return localTailnetTLSHandler(logger, lc, next) 417 - } 418 - 419 - // oidcFunnelHandlers serves Funnel requests, requiring authentication via the configured OIDC issuer. 420 - func oidcFunnelHandler(ctx context.Context, logger *slog.Logger, lc tailscaleLocalClient, cfg *funnelConfig, next http.Handler) (http.Handler, error) { 421 - st, err := lc.StatusWithoutPeers(ctx) 422 - if err != nil { 423 - return nil, fmt.Errorf("tailscale status: %w", err) 424 - } 425 - 426 - redir := &url.URL{Scheme: "https", Path: ".oidc-callback"} 427 - redir.Host = strings.TrimSuffix(st.Self.DNSName, ".") 428 - 429 - wrapper, err := middleware.NewFromDiscovery(ctx, nil, cfg.Issuer, cfg.ClientID, cfg.ClientSecret, redir.String()) 430 - if err != nil { 431 - return nil, fmt.Errorf("oidc middleware: %w", err) 432 - } 433 - wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile) 434 - 435 - return wrapper.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 436 - if r.TLS == nil { 437 - panic("oidc handler wants tls") 438 - } 439 - 440 - _, err := tsWhoIs(lc, r) 419 + func oidcFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler { 420 + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 421 + whois, err := tsWhoIs(lc, r) 441 422 if err != nil { 442 423 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 443 424 logger.ErrorContext(r.Context(), "tailscale whois", lerr(err)) 444 425 return 445 426 } 427 + if !whois.Node.IsTagged() { 428 + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) 429 + logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node") 430 + return 431 + } 446 432 447 433 tok := middleware.IDJWTFromContext(r.Context()) 448 434 if tok == nil { 449 - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 435 + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) 450 436 logger.ErrorContext(r.Context(), "jwt token missing") 451 437 return 452 438 } ··· 467 453 req.Header.Set("X-Webauth-User", email) 468 454 req.Header.Set("X-Webauth-Name", name) 469 455 470 - next.ServeHTTP(w, r) 471 - })), nil 456 + next.ServeHTTP(w, req) 457 + }) 472 458 } 473 459 474 460 type tailscaleLocalClient interface { 475 461 WhoIs(context.Context, string) (*apitype.WhoIsResponse, error) 476 - StatusWithoutPeers(context.Context) (*ipnstate.Status, error) 477 462 } 478 463 479 464 func tsWhoIs(lc tailscaleLocalClient, r *http.Request) (*apitype.WhoIsResponse, error) {
+225 -85
tsproxy_test.go
··· 8 8 "log/slog" 9 9 "net/http" 10 10 "net/http/httptest" 11 + "strings" 11 12 "testing" 12 13 13 14 "github.com/google/go-cmp/cmp" 14 15 "github.com/prometheus/client_golang/prometheus" 15 16 "github.com/prometheus/client_golang/prometheus/testutil" 16 17 "tailscale.com/client/tailscale/apitype" 17 - "tailscale.com/ipn/ipnstate" 18 18 "tailscale.com/tailcfg" 19 19 ) 20 20 21 21 type fakeLocalClient struct { 22 - whois func(context.Context, string) (*apitype.WhoIsResponse, error) 23 - status func(context.Context) (*ipnstate.Status, error) 22 + whois func(context.Context, string) (*apitype.WhoIsResponse, error) 24 23 } 25 24 26 25 func (c *fakeLocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { 26 + if c.whois == nil { 27 + return nil, errors.New("not implemented") 28 + } 27 29 return c.whois(ctx, remoteAddr) 28 30 } 29 31 30 - func (c *fakeLocalClient) StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { 31 - return c.status(ctx) 32 - } 33 - 34 - func TestLocalTailnetHandler(t *testing.T) { 32 + func TestTSHandlers(t *testing.T) { 35 33 t.Parallel() 36 34 35 + logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) 36 + 37 37 for _, tc := range []struct { 38 38 name string 39 39 whois func(context.Context, string) (*apitype.WhoIsResponse, error) 40 - want int 40 + handler func(*slog.Logger, tailscaleLocalClient, http.Handler) http.Handler 41 + wantNext bool 42 + wantStatus int 41 43 wantHeaders map[string]string 44 + wantBody string 42 45 }{ 43 46 { 44 - name: "tailscale whois error", 47 + name: "tailnet: tailscale whois error", 48 + handler: tailnet, 45 49 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 46 50 return nil, errors.New("whois error") 47 51 }, 48 - want: http.StatusInternalServerError, 52 + wantStatus: http.StatusInternalServerError, 53 + wantBody: "Internal Server Error", 49 54 }, 50 55 { 51 - name: "tailscale whois no profile", 56 + name: "tailnet: tailscale whois no profile", 57 + handler: tailnet, 52 58 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 53 - return &apitype.WhoIsResponse{}, nil 59 + return &apitype.WhoIsResponse{Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil 54 60 }, 55 - want: http.StatusInternalServerError, 61 + wantStatus: http.StatusInternalServerError, 62 + wantBody: "Internal Server Error", 56 63 }, 57 64 { 58 - name: "tailscale whois no node", 65 + name: "tailnet: tailscale whois no node", 66 + handler: tailnet, 59 67 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 60 68 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil 61 69 }, 62 - want: http.StatusInternalServerError, 70 + wantStatus: http.StatusInternalServerError, 71 + wantBody: "Internal Server Error", 63 72 }, 64 73 { 65 - name: "tailscale whois ok (tagged node)", 74 + name: "tailnet: tailscale whois ok (tagged node)", 75 + handler: tailnet, 66 76 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 67 77 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil 68 78 }, 69 - want: http.StatusOK, 79 + wantNext: true, 80 + wantStatus: http.StatusOK, 81 + wantBody: "OK", 82 + wantHeaders: map[string]string{ 83 + "X-Webauth-User": "", 84 + "X-Webauth-Name": "", 85 + }, 70 86 }, 71 87 { 72 - name: "tailscale whois ok (user)", 88 + name: "tailnet: tailscale whois ok (user)", 89 + handler: tailnet, 73 90 whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 74 91 return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil 75 92 }, 76 - want: http.StatusOK, 93 + wantNext: true, 94 + wantStatus: http.StatusOK, 95 + wantBody: "OK", 77 96 wantHeaders: map[string]string{ 78 97 "X-Webauth-User": "login", 79 98 "X-Webauth-Name": "name", 80 99 }, 81 100 }, 101 + { 102 + name: "insecure: tailscale whois error", 103 + handler: insecureFunnel, 104 + whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 105 + return nil, errors.New("whois error") 106 + }, 107 + wantStatus: http.StatusInternalServerError, 108 + wantBody: "Internal Server Error", 109 + }, 110 + { 111 + name: "insecure: tailscale whois no profile", 112 + handler: insecureFunnel, 113 + whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 114 + return &apitype.WhoIsResponse{Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil 115 + }, 116 + wantStatus: http.StatusInternalServerError, 117 + wantBody: "Internal Server Error", 118 + }, 119 + { 120 + name: "insure: tailscale whois no node", 121 + handler: insecureFunnel, 122 + whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 123 + return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil 124 + }, 125 + wantStatus: http.StatusInternalServerError, 126 + wantBody: "Internal Server Error", 127 + }, 128 + { 129 + name: "insecure: tagged node", 130 + handler: insecureFunnel, 131 + whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 132 + return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil 133 + }, 134 + wantNext: true, 135 + wantStatus: http.StatusOK, 136 + wantBody: "OK", 137 + wantHeaders: map[string]string{ 138 + "X-Webauth-User": "", 139 + "X-Webauth-Name": "", 140 + }, 141 + }, 142 + { 143 + name: "insecure: user node", 144 + handler: insecureFunnel, 145 + whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 146 + return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil 147 + }, 148 + wantStatus: http.StatusUnauthorized, 149 + wantBody: "Unauthorized", 150 + }, 151 + { 152 + name: "oidc: tailscale whois error", 153 + handler: oidcFunnel, 154 + whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 155 + return nil, errors.New("whois error") 156 + }, 157 + wantStatus: http.StatusInternalServerError, 158 + wantBody: "Internal Server Error", 159 + }, 160 + { 161 + name: "oidc: tailscale whois no profile", 162 + handler: oidcFunnel, 163 + whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 164 + return &apitype.WhoIsResponse{Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil 165 + }, 166 + wantStatus: http.StatusInternalServerError, 167 + wantBody: "Internal Server Error", 168 + }, 169 + { 170 + name: "oidc: tailscale whois no node", 171 + handler: oidcFunnel, 172 + whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 173 + return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil 174 + }, 175 + wantStatus: http.StatusInternalServerError, 176 + wantBody: "Internal Server Error", 177 + }, 178 + { 179 + name: "oidc: user node", 180 + handler: oidcFunnel, 181 + whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 182 + return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil 183 + }, 184 + wantStatus: http.StatusUnauthorized, 185 + wantBody: "Unauthorized", 186 + }, 187 + { 188 + name: "oidc: tagged node", 189 + handler: oidcFunnel, 190 + whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 191 + return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"tag:ingress"}}}, nil 192 + }, 193 + wantStatus: http.StatusUnauthorized, 194 + wantBody: "Unauthorized", 195 + }, 82 196 } { 83 - tc := tc 84 197 t.Run(tc.name, func(t *testing.T) { 85 198 t.Parallel() 86 - lc := &fakeLocalClient{whois: tc.whois} 87 - be := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 88 - for k, v := range r.Header { 89 - w.Header().Set(k, v[0]) 90 - } 91 - fmt.Fprintln(w, "Hi from the backend.") 92 - }) 93 - px := httptest.NewServer(localTailnetHandler(slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), lc, be)) 94 - defer px.Close() 199 + 200 + var nextReq *http.Request 201 + h := tc.handler(logger, &fakeLocalClient{whois: tc.whois}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 202 + nextReq = r 203 + fmt.Fprintf(w, "OK") 204 + })) 205 + w := httptest.NewRecorder() 206 + h.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "http://example.com/path", nil)) 207 + resp := w.Result() 208 + 209 + if want, got := tc.wantStatus, resp.StatusCode; want != got { 210 + t.Errorf("want status %d, got: %d", want, got) 211 + } 95 212 96 - resp, err := http.Get(px.URL) 213 + body, err := io.ReadAll(resp.Body) 97 214 if err != nil { 98 215 t.Fatal(err) 99 216 } 100 - defer resp.Body.Close() 101 - 102 - if want, got := tc.want, resp.StatusCode; want != got { 103 - t.Errorf("want status %d, got: %d", want, got) 217 + if !strings.Contains(string(body), tc.wantBody) { 218 + t.Errorf("want body %q, got: %q", tc.wantBody, string(body)) 104 219 } 105 - if tc.wantHeaders == nil { 106 - tc.wantHeaders = map[string]string{ 107 - "X-Webauth-User": "", 108 - "X-Webauth-Name": "", 109 - } 220 + if tc.wantNext && nextReq == nil { 221 + t.Fatalf("next handler not called") 110 222 } 111 223 for k, want := range tc.wantHeaders { 112 - if got := resp.Header.Get(k); got != want { 113 - t.Errorf("want header %s %s, got: %s", k, want, got) 224 + if got := nextReq.Header.Get(k); got != want { 225 + t.Errorf("want header %s = %s, got: %s", k, want, got) 114 226 } 115 227 } 116 228 }) 117 229 } 118 230 } 119 231 120 - func TestLocalTailnetTLSHandler(t *testing.T) { 232 + func TestRedirectHandler(t *testing.T) { 121 233 t.Parallel() 122 234 123 - lc := &fakeLocalClient{ 124 - whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) { 125 - return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil 235 + for _, tc := range []struct { 236 + name string 237 + forceSSL bool 238 + fqdn string 239 + request *http.Request 240 + wantNext bool 241 + wantStatus int 242 + wantLocation string 243 + }{ 244 + { 245 + name: "forceSSL: redirect", 246 + forceSSL: true, 247 + fqdn: "http://example.com", 248 + request: httptest.NewRequest("", "/path", nil), 249 + wantStatus: http.StatusPermanentRedirect, 250 + wantLocation: "https://example.com/path", 251 + }, 252 + { 253 + name: "forceSSL: ok", 254 + forceSSL: true, 255 + fqdn: "example.com", 256 + request: httptest.NewRequest("", "https://example.com/path", nil), 257 + wantNext: true, 258 + wantStatus: http.StatusOK, 259 + }, 260 + { 261 + name: "fqdn: redirect", 262 + fqdn: "example.ts.net", 263 + request: httptest.NewRequest("", "https://example/path", nil), 264 + wantStatus: http.StatusPermanentRedirect, 265 + wantLocation: "https://example.ts.net/path", 266 + }, 267 + { 268 + name: "fqdn: ok", 269 + fqdn: "example.ts.net", 270 + request: httptest.NewRequest("", "https://example.ts.net/path", nil), 271 + wantNext: true, 272 + wantStatus: http.StatusOK, 126 273 }, 127 - status: func(_ context.Context) (*ipnstate.Status, error) { 128 - return &ipnstate.Status{Self: &ipnstate.PeerStatus{DNSName: "foo.ts.net."}}, nil 274 + { 275 + name: "fqdn: ok (not tls)", 276 + fqdn: "example.ts.net", 277 + request: httptest.NewRequest("", "/path", nil), 278 + wantNext: true, 279 + wantStatus: http.StatusOK, 129 280 }, 130 - } 131 - be := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 132 - fmt.Fprintln(w, "Hi from the backend.") 133 - }) 134 - px := httptest.NewTLSServer(localTailnetTLSHandler(slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), lc, be)) 135 - defer px.Close() 281 + } { 282 + t.Run(tc.name, func(t *testing.T) { 283 + t.Parallel() 136 284 137 - cli := px.Client() 138 - cli.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { 139 - return http.ErrUseLastResponse 140 - } 285 + var nextReq *http.Request 286 + h := redirect(tc.fqdn, tc.forceSSL, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 287 + nextReq = r 288 + fmt.Fprintf(w, "OK") 289 + })) 290 + w := httptest.NewRecorder() 291 + h.ServeHTTP(w, tc.request) 292 + resp := w.Result() 141 293 142 - req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, px.URL+"/bar", nil) 143 - if err != nil { 144 - t.Fatal(err) 145 - } 146 - resp, err := cli.Do(req) 147 - if err != nil { 148 - t.Fatal(err) 149 - } 150 - defer resp.Body.Close() 151 - if want, got := http.StatusPermanentRedirect, resp.StatusCode; want != got { 152 - t.Fatalf("want status %d, got: %d", want, got) 153 - } 154 - if want, got := "https://foo.ts.net/bar", resp.Header.Get("location"); got != want { 155 - t.Fatalf("want Location %s, got: %s", want, got) 156 - } 294 + if want, got := tc.wantStatus, resp.StatusCode; want != got { 295 + t.Errorf("want status %d, got: %d", want, got) 296 + } 157 297 158 - req, err = http.NewRequestWithContext(t.Context(), http.MethodGet, px.URL, nil) 159 - if err != nil { 160 - t.Fatal(err) 161 - } 162 - req.Host = "foo.ts.net" 163 - resp, err = px.Client().Do(req) 164 - if err != nil { 165 - t.Fatal(err) 166 - } 167 - defer resp.Body.Close() 168 - if want, got := http.StatusOK, resp.StatusCode; want != got { 169 - t.Fatalf("want status %d, got: %d", want, got) 298 + if tc.wantNext && nextReq == nil { 299 + t.Fatalf("next handler not called") 300 + } 301 + if !tc.wantNext && nextReq != nil { 302 + t.Fatalf("next handler was called") 303 + } 304 + if nextReq != nil { 305 + if want, got := tc.wantLocation, nextReq.Header.Get("Location"); got != want { 306 + t.Errorf("want Location header %s, got: %s", want, got) 307 + } 308 + } 309 + }) 170 310 } 171 311 } 172 312