tangled
alpha
login
or
join now
sr.aux1.dev
/
tsproxy
2
fork
atom
HTTP reverse proxy for Tailscale
2
fork
atom
overview
issues
pulls
1
pipelines
revamp http handlers implementation and tests
sr.aux1.dev
9 months ago
12436389
4d24fc97
+289
-164
3 changed files
expand all
collapse all
unified
split
go.mod
tsproxy.go
tsproxy_test.go
+1
-1
go.mod
···
6
6
7
7
require (
8
8
github.com/google/go-cmp v0.7.0
9
9
+
github.com/lstoll/oidc v1.0.0-beta.4.0.20250106123456-6ffce62670fe
9
10
github.com/oklog/run v1.1.0
10
11
github.com/prometheus/client_golang v1.21.1
11
12
github.com/prometheus/common v0.63.0
···
56
57
github.com/klauspost/compress v1.17.11 // indirect
57
58
github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a // indirect
58
59
github.com/kylelemons/godebug v1.1.0 // indirect
59
59
-
github.com/lstoll/oidc v1.0.0-beta.4.0.20250106123456-6ffce62670fe // indirect
60
60
github.com/mdlayher/genetlink v1.3.2 // indirect
61
61
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
62
62
github.com/mdlayher/sdnotify v1.0.0 // indirect
+63
-78
main.go
tsproxy.go
···
30
30
"github.com/tailscale/hujson"
31
31
"tailscale.com/client/local"
32
32
"tailscale.com/client/tailscale/apitype"
33
33
-
"tailscale.com/ipn/ipnstate"
34
33
"tailscale.com/tsnet"
35
34
tslogger "tailscale.com/types/logger"
36
35
)
···
69
68
Name string
70
69
Backend string
71
70
Prometheus bool
72
72
-
Funnel *funnelConfig `json:"funnel,omitempty"`
71
71
+
Funnel *funnelConfig
73
72
}
74
73
75
74
type funnelConfig struct {
···
265
264
return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
266
265
}
267
266
268
268
-
srv = &http.Server{Handler: instrument(localTailnetHandler(log, lc, proxy))}
267
267
+
srv = &http.Server{Handler: instrument(redirect(st.Self.DNSName, false, tailnet(log, lc, proxy)))}
269
268
ln, err := ts.Listen("tcp", ":80")
270
269
if err != nil {
271
270
return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.Name, err)
···
287
286
{
288
287
var srv *http.Server
289
288
g.Add(func() error {
290
290
-
_, err := ts.Up(ctx)
289
289
+
st, err := ts.Up(ctx)
291
290
if err != nil {
292
291
return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
293
292
}
293
293
+
294
294
srv = &http.Server{
295
295
TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate},
296
296
-
Handler: instrument(localTailnetTLSHandler(log, lc, proxy)),
296
296
+
Handler: instrument(redirect(st.Self.DNSName, true, tailnet(log, lc, proxy))),
297
297
}
298
298
299
299
ln, err := ts.Listen("tcp", ":443")
···
304
304
}, func(_ error) {
305
305
if srv != nil {
306
306
if err := srv.Close(); err != nil {
307
307
-
log.Error("TLS server shutdown", lerr(err))
307
307
+
log.Error("server shutdown", lerr(err))
308
308
}
309
309
}
310
310
cancel()
311
311
})
312
312
}
313
313
if funnel := upstream.Funnel; funnel != nil {
314
314
-
if !funnel.Insecure && funnel.Issuer == "" {
315
315
-
return fmt.Errorf("upstream %s: funnel must set issuer or insecure", upstream.Name)
316
316
-
}
317
314
{
318
315
var srv *http.Server
319
316
g.Add(func() error {
320
320
-
_, err := ts.Up(ctx)
317
317
+
st, err := ts.Up(ctx)
321
318
if err != nil {
322
319
return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
323
320
}
324
321
325
325
-
srv = &http.Server{}
326
326
-
if funnel.Issuer != "" {
327
327
-
handler, err := oidcFunnelHandler(ctx, log, lc, funnel, proxy)
322
322
+
var handler http.Handler
323
323
+
switch {
324
324
+
case funnel.Insecure:
325
325
+
handler = insecureFunnel(log, lc, proxy)
326
326
+
case funnel.Issuer != "":
327
327
+
redir := &url.URL{Scheme: "https", Host: strings.TrimSuffix(st.Self.DNSName, "."), Path: ".oidc-callback"}
328
328
+
wrapper, err := middleware.NewFromDiscovery(ctx, nil, funnel.Issuer, funnel.ClientID, funnel.ClientSecret, redir.String())
328
329
if err != nil {
329
329
-
return fmt.Errorf("oidc: %w", err)
330
330
+
return fmt.Errorf("oidc middleware for %s: %w", upstream.Name, err)
330
331
}
331
331
-
srv.Handler = handler
332
332
-
} else if funnel.Insecure {
333
333
-
srv.Handler = insecureFunnelHandler(log, lc, proxy)
334
334
-
} else {
335
335
-
panic("funnel misconfigured")
332
332
+
wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile)
333
333
+
334
334
+
handler = wrapper.Wrap(oidcFunnel(log, lc, proxy))
335
335
+
default:
336
336
+
return fmt.Errorf("upstream %s must set funnel.insecure or funnel.issuer", upstream.Name)
336
337
}
337
337
-
srv.Handler = instrument(srv.Handler)
338
338
+
srv = &http.Server{Handler: instrument(redirect(st.Self.DNSName, true, handler))}
338
339
339
340
ln, err := ts.ListenFunnel("tcp", ":443", tsnet.FunnelOnly())
340
341
if err != nil {
···
344
345
}, func(_ error) {
345
346
if srv != nil {
346
347
if err := srv.Close(); err != nil {
347
347
-
log.Error("TLS server shutdown", lerr(err))
348
348
+
log.Error("server shutdown", lerr(err))
348
349
}
349
350
}
350
351
cancel()
···
356
357
return g.Run()
357
358
}
358
359
359
359
-
// localTailnetHandler serves plain-HTTP on the local tailnet.
360
360
-
func localTailnetHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
360
360
+
func redirect(fqdn string, forceSSL bool, next http.Handler) http.Handler {
361
361
+
if fqdn == "" {
362
362
+
panic("redirect: fqdn cannot be empty")
363
363
+
}
364
364
+
fqdn = strings.TrimSuffix(fqdn, ".")
365
365
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
366
366
+
if forceSSL && r.TLS == nil {
367
367
+
http.Redirect(w, r, fmt.Sprintf("https://%s%s", fqdn, r.RequestURI), http.StatusPermanentRedirect)
368
368
+
return
369
369
+
}
370
370
+
371
371
+
if r.TLS != nil && strings.TrimSuffix(r.Host, ".") != fqdn {
372
372
+
http.Redirect(w, r, fmt.Sprintf("https://%s%s", fqdn, r.RequestURI), http.StatusPermanentRedirect)
373
373
+
return
374
374
+
}
375
375
+
next.ServeHTTP(w, r)
376
376
+
})
377
377
+
}
378
378
+
379
379
+
func tailnet(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
361
380
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
362
381
whois, err := tsWhoIs(lc, r)
363
382
if err != nil {
···
379
398
})
380
399
}
381
400
382
382
-
// localTailnetTLSHandler serves HTTPS on the local tailnet.
383
383
-
func localTailnetTLSHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
384
384
-
var (
385
385
-
handler = localTailnetHandler(logger, lc, next)
386
386
-
dnsName string
387
387
-
)
401
401
+
func insecureFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
388
402
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
389
389
-
if r.TLS == nil {
390
390
-
panic("tailnet handler wants tls")
403
403
+
whois, err := tsWhoIs(lc, r)
404
404
+
if err != nil {
405
405
+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
406
406
+
logger.ErrorContext(r.Context(), "tailscale whois", lerr(err))
407
407
+
return
391
408
}
392
392
-
393
393
-
if dnsName == "" {
394
394
-
st, err := lc.StatusWithoutPeers(r.Context())
395
395
-
if err != nil {
396
396
-
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
397
397
-
logger.ErrorContext(r.Context(), "tailscale status", slog.Any("err", err))
398
398
-
return
399
399
-
}
400
400
-
dnsName = strings.TrimSuffix(st.Self.DNSName, ".")
401
401
-
}
402
402
-
403
403
-
if strings.TrimSuffix(r.Host, ".") != dnsName {
404
404
-
http.Redirect(w, r, fmt.Sprintf("https://%s%s", dnsName, r.RequestURI), http.StatusPermanentRedirect)
409
409
+
if !whois.Node.IsTagged() {
410
410
+
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
411
411
+
logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node")
405
412
return
406
413
}
407
414
408
408
-
handler.ServeHTTP(w, r)
415
415
+
next.ServeHTTP(w, r)
409
416
})
410
417
}
411
418
412
412
-
// insecureFunnelHandler handles HTTPS requests coming from Tailscale Funnel nodes.
413
413
-
// This is marked insecure because the upstream is exposed to the public Internet.
414
414
-
// The upstream is responsible for implementing authentication.
415
415
-
func insecureFunnelHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
416
416
-
return localTailnetTLSHandler(logger, lc, next)
417
417
-
}
418
418
-
419
419
-
// oidcFunnelHandlers serves Funnel requests, requiring authentication via the configured OIDC issuer.
420
420
-
func oidcFunnelHandler(ctx context.Context, logger *slog.Logger, lc tailscaleLocalClient, cfg *funnelConfig, next http.Handler) (http.Handler, error) {
421
421
-
st, err := lc.StatusWithoutPeers(ctx)
422
422
-
if err != nil {
423
423
-
return nil, fmt.Errorf("tailscale status: %w", err)
424
424
-
}
425
425
-
426
426
-
redir := &url.URL{Scheme: "https", Path: ".oidc-callback"}
427
427
-
redir.Host = strings.TrimSuffix(st.Self.DNSName, ".")
428
428
-
429
429
-
wrapper, err := middleware.NewFromDiscovery(ctx, nil, cfg.Issuer, cfg.ClientID, cfg.ClientSecret, redir.String())
430
430
-
if err != nil {
431
431
-
return nil, fmt.Errorf("oidc middleware: %w", err)
432
432
-
}
433
433
-
wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile)
434
434
-
435
435
-
return wrapper.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
436
436
-
if r.TLS == nil {
437
437
-
panic("oidc handler wants tls")
438
438
-
}
439
439
-
440
440
-
_, err := tsWhoIs(lc, r)
419
419
+
func oidcFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
420
420
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
421
421
+
whois, err := tsWhoIs(lc, r)
441
422
if err != nil {
442
423
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
443
424
logger.ErrorContext(r.Context(), "tailscale whois", lerr(err))
444
425
return
445
426
}
427
427
+
if !whois.Node.IsTagged() {
428
428
+
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
429
429
+
logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node")
430
430
+
return
431
431
+
}
446
432
447
433
tok := middleware.IDJWTFromContext(r.Context())
448
434
if tok == nil {
449
449
-
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
435
435
+
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
450
436
logger.ErrorContext(r.Context(), "jwt token missing")
451
437
return
452
438
}
···
467
453
req.Header.Set("X-Webauth-User", email)
468
454
req.Header.Set("X-Webauth-Name", name)
469
455
470
470
-
next.ServeHTTP(w, r)
471
471
-
})), nil
456
456
+
next.ServeHTTP(w, req)
457
457
+
})
472
458
}
473
459
474
460
type tailscaleLocalClient interface {
475
461
WhoIs(context.Context, string) (*apitype.WhoIsResponse, error)
476
476
-
StatusWithoutPeers(context.Context) (*ipnstate.Status, error)
477
462
}
478
463
479
464
func tsWhoIs(lc tailscaleLocalClient, r *http.Request) (*apitype.WhoIsResponse, error) {
+225
-85
tsproxy_test.go
···
8
8
"log/slog"
9
9
"net/http"
10
10
"net/http/httptest"
11
11
+
"strings"
11
12
"testing"
12
13
13
14
"github.com/google/go-cmp/cmp"
14
15
"github.com/prometheus/client_golang/prometheus"
15
16
"github.com/prometheus/client_golang/prometheus/testutil"
16
17
"tailscale.com/client/tailscale/apitype"
17
17
-
"tailscale.com/ipn/ipnstate"
18
18
"tailscale.com/tailcfg"
19
19
)
20
20
21
21
type fakeLocalClient struct {
22
22
-
whois func(context.Context, string) (*apitype.WhoIsResponse, error)
23
23
-
status func(context.Context) (*ipnstate.Status, error)
22
22
+
whois func(context.Context, string) (*apitype.WhoIsResponse, error)
24
23
}
25
24
26
25
func (c *fakeLocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) {
26
26
+
if c.whois == nil {
27
27
+
return nil, errors.New("not implemented")
28
28
+
}
27
29
return c.whois(ctx, remoteAddr)
28
30
}
29
31
30
30
-
func (c *fakeLocalClient) StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) {
31
31
-
return c.status(ctx)
32
32
-
}
33
33
-
34
34
-
func TestLocalTailnetHandler(t *testing.T) {
32
32
+
func TestTSHandlers(t *testing.T) {
35
33
t.Parallel()
36
34
35
35
+
logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
36
36
+
37
37
for _, tc := range []struct {
38
38
name string
39
39
whois func(context.Context, string) (*apitype.WhoIsResponse, error)
40
40
-
want int
40
40
+
handler func(*slog.Logger, tailscaleLocalClient, http.Handler) http.Handler
41
41
+
wantNext bool
42
42
+
wantStatus int
41
43
wantHeaders map[string]string
44
44
+
wantBody string
42
45
}{
43
46
{
44
44
-
name: "tailscale whois error",
47
47
+
name: "tailnet: tailscale whois error",
48
48
+
handler: tailnet,
45
49
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
46
50
return nil, errors.New("whois error")
47
51
},
48
48
-
want: http.StatusInternalServerError,
52
52
+
wantStatus: http.StatusInternalServerError,
53
53
+
wantBody: "Internal Server Error",
49
54
},
50
55
{
51
51
-
name: "tailscale whois no profile",
56
56
+
name: "tailnet: tailscale whois no profile",
57
57
+
handler: tailnet,
52
58
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
53
53
-
return &apitype.WhoIsResponse{}, nil
59
59
+
return &apitype.WhoIsResponse{Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil
54
60
},
55
55
-
want: http.StatusInternalServerError,
61
61
+
wantStatus: http.StatusInternalServerError,
62
62
+
wantBody: "Internal Server Error",
56
63
},
57
64
{
58
58
-
name: "tailscale whois no node",
65
65
+
name: "tailnet: tailscale whois no node",
66
66
+
handler: tailnet,
59
67
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
60
68
return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil
61
69
},
62
62
-
want: http.StatusInternalServerError,
70
70
+
wantStatus: http.StatusInternalServerError,
71
71
+
wantBody: "Internal Server Error",
63
72
},
64
73
{
65
65
-
name: "tailscale whois ok (tagged node)",
74
74
+
name: "tailnet: tailscale whois ok (tagged node)",
75
75
+
handler: tailnet,
66
76
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
67
77
return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil
68
78
},
69
69
-
want: http.StatusOK,
79
79
+
wantNext: true,
80
80
+
wantStatus: http.StatusOK,
81
81
+
wantBody: "OK",
82
82
+
wantHeaders: map[string]string{
83
83
+
"X-Webauth-User": "",
84
84
+
"X-Webauth-Name": "",
85
85
+
},
70
86
},
71
87
{
72
72
-
name: "tailscale whois ok (user)",
88
88
+
name: "tailnet: tailscale whois ok (user)",
89
89
+
handler: tailnet,
73
90
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
74
91
return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil
75
92
},
76
76
-
want: http.StatusOK,
93
93
+
wantNext: true,
94
94
+
wantStatus: http.StatusOK,
95
95
+
wantBody: "OK",
77
96
wantHeaders: map[string]string{
78
97
"X-Webauth-User": "login",
79
98
"X-Webauth-Name": "name",
80
99
},
81
100
},
101
101
+
{
102
102
+
name: "insecure: tailscale whois error",
103
103
+
handler: insecureFunnel,
104
104
+
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
105
105
+
return nil, errors.New("whois error")
106
106
+
},
107
107
+
wantStatus: http.StatusInternalServerError,
108
108
+
wantBody: "Internal Server Error",
109
109
+
},
110
110
+
{
111
111
+
name: "insecure: tailscale whois no profile",
112
112
+
handler: insecureFunnel,
113
113
+
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
114
114
+
return &apitype.WhoIsResponse{Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil
115
115
+
},
116
116
+
wantStatus: http.StatusInternalServerError,
117
117
+
wantBody: "Internal Server Error",
118
118
+
},
119
119
+
{
120
120
+
name: "insure: tailscale whois no node",
121
121
+
handler: insecureFunnel,
122
122
+
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
123
123
+
return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil
124
124
+
},
125
125
+
wantStatus: http.StatusInternalServerError,
126
126
+
wantBody: "Internal Server Error",
127
127
+
},
128
128
+
{
129
129
+
name: "insecure: tagged node",
130
130
+
handler: insecureFunnel,
131
131
+
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
132
132
+
return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil
133
133
+
},
134
134
+
wantNext: true,
135
135
+
wantStatus: http.StatusOK,
136
136
+
wantBody: "OK",
137
137
+
wantHeaders: map[string]string{
138
138
+
"X-Webauth-User": "",
139
139
+
"X-Webauth-Name": "",
140
140
+
},
141
141
+
},
142
142
+
{
143
143
+
name: "insecure: user node",
144
144
+
handler: insecureFunnel,
145
145
+
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
146
146
+
return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil
147
147
+
},
148
148
+
wantStatus: http.StatusUnauthorized,
149
149
+
wantBody: "Unauthorized",
150
150
+
},
151
151
+
{
152
152
+
name: "oidc: tailscale whois error",
153
153
+
handler: oidcFunnel,
154
154
+
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
155
155
+
return nil, errors.New("whois error")
156
156
+
},
157
157
+
wantStatus: http.StatusInternalServerError,
158
158
+
wantBody: "Internal Server Error",
159
159
+
},
160
160
+
{
161
161
+
name: "oidc: tailscale whois no profile",
162
162
+
handler: oidcFunnel,
163
163
+
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
164
164
+
return &apitype.WhoIsResponse{Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil
165
165
+
},
166
166
+
wantStatus: http.StatusInternalServerError,
167
167
+
wantBody: "Internal Server Error",
168
168
+
},
169
169
+
{
170
170
+
name: "oidc: tailscale whois no node",
171
171
+
handler: oidcFunnel,
172
172
+
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
173
173
+
return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login"}}, nil
174
174
+
},
175
175
+
wantStatus: http.StatusInternalServerError,
176
176
+
wantBody: "Internal Server Error",
177
177
+
},
178
178
+
{
179
179
+
name: "oidc: user node",
180
180
+
handler: oidcFunnel,
181
181
+
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
182
182
+
return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "login", DisplayName: "name"}, Node: &tailcfg.Node{Name: "login.ts.net"}}, nil
183
183
+
},
184
184
+
wantStatus: http.StatusUnauthorized,
185
185
+
wantBody: "Unauthorized",
186
186
+
},
187
187
+
{
188
188
+
name: "oidc: tagged node",
189
189
+
handler: oidcFunnel,
190
190
+
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
191
191
+
return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"tag:ingress"}}}, nil
192
192
+
},
193
193
+
wantStatus: http.StatusUnauthorized,
194
194
+
wantBody: "Unauthorized",
195
195
+
},
82
196
} {
83
83
-
tc := tc
84
197
t.Run(tc.name, func(t *testing.T) {
85
198
t.Parallel()
86
86
-
lc := &fakeLocalClient{whois: tc.whois}
87
87
-
be := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
88
88
-
for k, v := range r.Header {
89
89
-
w.Header().Set(k, v[0])
90
90
-
}
91
91
-
fmt.Fprintln(w, "Hi from the backend.")
92
92
-
})
93
93
-
px := httptest.NewServer(localTailnetHandler(slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), lc, be))
94
94
-
defer px.Close()
199
199
+
200
200
+
var nextReq *http.Request
201
201
+
h := tc.handler(logger, &fakeLocalClient{whois: tc.whois}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
202
202
+
nextReq = r
203
203
+
fmt.Fprintf(w, "OK")
204
204
+
}))
205
205
+
w := httptest.NewRecorder()
206
206
+
h.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "http://example.com/path", nil))
207
207
+
resp := w.Result()
208
208
+
209
209
+
if want, got := tc.wantStatus, resp.StatusCode; want != got {
210
210
+
t.Errorf("want status %d, got: %d", want, got)
211
211
+
}
95
212
96
96
-
resp, err := http.Get(px.URL)
213
213
+
body, err := io.ReadAll(resp.Body)
97
214
if err != nil {
98
215
t.Fatal(err)
99
216
}
100
100
-
defer resp.Body.Close()
101
101
-
102
102
-
if want, got := tc.want, resp.StatusCode; want != got {
103
103
-
t.Errorf("want status %d, got: %d", want, got)
217
217
+
if !strings.Contains(string(body), tc.wantBody) {
218
218
+
t.Errorf("want body %q, got: %q", tc.wantBody, string(body))
104
219
}
105
105
-
if tc.wantHeaders == nil {
106
106
-
tc.wantHeaders = map[string]string{
107
107
-
"X-Webauth-User": "",
108
108
-
"X-Webauth-Name": "",
109
109
-
}
220
220
+
if tc.wantNext && nextReq == nil {
221
221
+
t.Fatalf("next handler not called")
110
222
}
111
223
for k, want := range tc.wantHeaders {
112
112
-
if got := resp.Header.Get(k); got != want {
113
113
-
t.Errorf("want header %s %s, got: %s", k, want, got)
224
224
+
if got := nextReq.Header.Get(k); got != want {
225
225
+
t.Errorf("want header %s = %s, got: %s", k, want, got)
114
226
}
115
227
}
116
228
})
117
229
}
118
230
}
119
231
120
120
-
func TestLocalTailnetTLSHandler(t *testing.T) {
232
232
+
func TestRedirectHandler(t *testing.T) {
121
233
t.Parallel()
122
234
123
123
-
lc := &fakeLocalClient{
124
124
-
whois: func(_ context.Context, _ string) (*apitype.WhoIsResponse, error) {
125
125
-
return &apitype.WhoIsResponse{UserProfile: &tailcfg.UserProfile{LoginName: "tagged-devices"}, Node: &tailcfg.Node{Tags: []string{"foo"}}}, nil
235
235
+
for _, tc := range []struct {
236
236
+
name string
237
237
+
forceSSL bool
238
238
+
fqdn string
239
239
+
request *http.Request
240
240
+
wantNext bool
241
241
+
wantStatus int
242
242
+
wantLocation string
243
243
+
}{
244
244
+
{
245
245
+
name: "forceSSL: redirect",
246
246
+
forceSSL: true,
247
247
+
fqdn: "http://example.com",
248
248
+
request: httptest.NewRequest("", "/path", nil),
249
249
+
wantStatus: http.StatusPermanentRedirect,
250
250
+
wantLocation: "https://example.com/path",
251
251
+
},
252
252
+
{
253
253
+
name: "forceSSL: ok",
254
254
+
forceSSL: true,
255
255
+
fqdn: "example.com",
256
256
+
request: httptest.NewRequest("", "https://example.com/path", nil),
257
257
+
wantNext: true,
258
258
+
wantStatus: http.StatusOK,
259
259
+
},
260
260
+
{
261
261
+
name: "fqdn: redirect",
262
262
+
fqdn: "example.ts.net",
263
263
+
request: httptest.NewRequest("", "https://example/path", nil),
264
264
+
wantStatus: http.StatusPermanentRedirect,
265
265
+
wantLocation: "https://example.ts.net/path",
266
266
+
},
267
267
+
{
268
268
+
name: "fqdn: ok",
269
269
+
fqdn: "example.ts.net",
270
270
+
request: httptest.NewRequest("", "https://example.ts.net/path", nil),
271
271
+
wantNext: true,
272
272
+
wantStatus: http.StatusOK,
126
273
},
127
127
-
status: func(_ context.Context) (*ipnstate.Status, error) {
128
128
-
return &ipnstate.Status{Self: &ipnstate.PeerStatus{DNSName: "foo.ts.net."}}, nil
274
274
+
{
275
275
+
name: "fqdn: ok (not tls)",
276
276
+
fqdn: "example.ts.net",
277
277
+
request: httptest.NewRequest("", "/path", nil),
278
278
+
wantNext: true,
279
279
+
wantStatus: http.StatusOK,
129
280
},
130
130
-
}
131
131
-
be := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
132
132
-
fmt.Fprintln(w, "Hi from the backend.")
133
133
-
})
134
134
-
px := httptest.NewTLSServer(localTailnetTLSHandler(slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), lc, be))
135
135
-
defer px.Close()
281
281
+
} {
282
282
+
t.Run(tc.name, func(t *testing.T) {
283
283
+
t.Parallel()
136
284
137
137
-
cli := px.Client()
138
138
-
cli.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
139
139
-
return http.ErrUseLastResponse
140
140
-
}
285
285
+
var nextReq *http.Request
286
286
+
h := redirect(tc.fqdn, tc.forceSSL, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
287
287
+
nextReq = r
288
288
+
fmt.Fprintf(w, "OK")
289
289
+
}))
290
290
+
w := httptest.NewRecorder()
291
291
+
h.ServeHTTP(w, tc.request)
292
292
+
resp := w.Result()
141
293
142
142
-
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, px.URL+"/bar", nil)
143
143
-
if err != nil {
144
144
-
t.Fatal(err)
145
145
-
}
146
146
-
resp, err := cli.Do(req)
147
147
-
if err != nil {
148
148
-
t.Fatal(err)
149
149
-
}
150
150
-
defer resp.Body.Close()
151
151
-
if want, got := http.StatusPermanentRedirect, resp.StatusCode; want != got {
152
152
-
t.Fatalf("want status %d, got: %d", want, got)
153
153
-
}
154
154
-
if want, got := "https://foo.ts.net/bar", resp.Header.Get("location"); got != want {
155
155
-
t.Fatalf("want Location %s, got: %s", want, got)
156
156
-
}
294
294
+
if want, got := tc.wantStatus, resp.StatusCode; want != got {
295
295
+
t.Errorf("want status %d, got: %d", want, got)
296
296
+
}
157
297
158
158
-
req, err = http.NewRequestWithContext(t.Context(), http.MethodGet, px.URL, nil)
159
159
-
if err != nil {
160
160
-
t.Fatal(err)
161
161
-
}
162
162
-
req.Host = "foo.ts.net"
163
163
-
resp, err = px.Client().Do(req)
164
164
-
if err != nil {
165
165
-
t.Fatal(err)
166
166
-
}
167
167
-
defer resp.Body.Close()
168
168
-
if want, got := http.StatusOK, resp.StatusCode; want != got {
169
169
-
t.Fatalf("want status %d, got: %d", want, got)
298
298
+
if tc.wantNext && nextReq == nil {
299
299
+
t.Fatalf("next handler not called")
300
300
+
}
301
301
+
if !tc.wantNext && nextReq != nil {
302
302
+
t.Fatalf("next handler was called")
303
303
+
}
304
304
+
if nextReq != nil {
305
305
+
if want, got := tc.wantLocation, nextReq.Header.Get("Location"); got != want {
306
306
+
t.Errorf("want Location header %s, got: %s", want, got)
307
307
+
}
308
308
+
}
309
309
+
})
170
310
}
171
311
}
172
312