+627
-212
Diff
round #0
+25
-9
go.mod
+25
-9
go.mod
···
4
4
5
5
6
6
7
-
7
+
require (
8
8
github.com/google/go-cmp v0.7.0
9
+
github.com/lstoll/oidc v1.0.0-beta.4.0.20250106123456-6ffce62670fe
9
10
github.com/oklog/run v1.1.0
10
11
github.com/prometheus/client_golang v1.21.1
12
+
github.com/prometheus/common v0.63.0
13
+
github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a
11
14
tailscale.com v1.82.0
12
15
)
13
16
···
51
54
52
55
53
56
57
+
github.com/klauspost/compress v1.17.11 // indirect
58
+
github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a // indirect
59
+
github.com/kylelemons/godebug v1.1.0 // indirect
60
+
github.com/mdlayher/genetlink v1.3.2 // indirect
61
+
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
62
+
github.com/mdlayher/sdnotify v1.0.0 // indirect
54
63
55
64
56
65
57
66
58
-
59
-
60
-
61
-
62
-
63
-
64
67
github.com/pierrec/lz4/v4 v4.1.21 // indirect
65
68
github.com/prometheus-community/pro-bing v0.4.0 // indirect
66
69
github.com/prometheus/client_model v0.6.1 // indirect
67
-
github.com/prometheus/common v0.63.0 // indirect
68
70
github.com/prometheus/procfs v0.15.1 // indirect
69
71
github.com/safchain/ethtool v0.3.0 // indirect
70
72
github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e // indirect
71
73
github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55 // indirect
72
74
github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 // indirect
73
-
github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a // indirect
74
75
github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 // indirect
75
76
github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc // indirect
76
77
github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 // indirect
78
+
github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19 // indirect
79
+
github.com/tink-crypto/tink-go/v2 v2.2.0 // indirect
80
+
github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect
81
+
github.com/vishvananda/netns v0.0.4 // indirect
82
+
github.com/x448/float16 v0.8.4 // indirect
83
+
84
+
85
+
86
+
golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac // indirect
87
+
golang.org/x/mod v0.23.0 // indirect
88
+
golang.org/x/net v0.36.0 // indirect
89
+
golang.org/x/oauth2 v0.26.0 // indirect
90
+
golang.org/x/sync v0.11.0 // indirect
91
+
golang.org/x/sys v0.30.0 // indirect
92
+
golang.org/x/term v0.29.0 // indirect
+303
-101
main.go
+303
-101
main.go
···
16
16
17
17
18
18
19
+
"strings"
20
+
"syscall"
19
21
20
-
21
-
22
-
23
-
24
-
22
+
"github.com/lstoll/oidc"
23
+
"github.com/lstoll/oidc/middleware"
24
+
"github.com/oklog/run"
25
+
"github.com/prometheus/client_golang/prometheus"
26
+
versioncollector "github.com/prometheus/client_golang/prometheus/collectors/version"
25
27
"github.com/prometheus/client_golang/prometheus/promauto"
26
28
"github.com/prometheus/client_golang/prometheus/promhttp"
27
29
"github.com/prometheus/common/version"
30
+
"github.com/tailscale/hujson"
28
31
"tailscale.com/client/local"
29
32
"tailscale.com/client/tailscale/apitype"
33
+
"tailscale.com/ipn/ipnstate"
30
34
"tailscale.com/tsnet"
35
+
tslogger "tailscale.com/types/logger"
36
+
)
31
37
32
38
33
39
···
56
62
57
63
58
64
59
-
60
-
61
65
)
62
66
)
63
67
64
-
type upstreamFlag []upstream
65
-
66
-
func (f *upstreamFlag) String() string {
67
-
return fmt.Sprintf("%+v", *f)
68
+
type upstream struct {
69
+
Name string
70
+
Backend string
71
+
Prometheus bool
72
+
Funnel *funnelConfig `json:"funnel,omitempty"`
68
73
}
69
74
70
-
func (f *upstreamFlag) Set(val string) error {
71
-
up, err := parseUpstreamFlag(val)
72
-
if err != nil {
73
-
return err
74
-
}
75
-
*f = append(*f, up)
76
-
return nil
75
+
type funnelConfig struct {
76
+
Insecure bool
77
+
Issuer string
78
+
ClientID string
79
+
ClientSecret string
77
80
}
78
81
79
-
type upstream struct {
80
-
name string
81
-
backend *url.URL
82
-
prometheus bool
83
-
funnel bool
84
-
}
85
-
86
82
type target struct {
87
83
88
84
89
85
prometheus bool
90
86
}
91
87
92
-
func parseUpstreamFlag(fval string) (upstream, error) {
93
-
k, v, ok := strings.Cut(fval, "=")
94
-
if !ok {
95
-
return upstream{}, errors.New("format: name=http://backend")
96
-
}
97
-
val := strings.Split(v, ";")
98
-
be, err := url.Parse(val[0])
99
-
if err != nil {
100
-
return upstream{}, err
101
-
}
102
-
up := upstream{name: k, backend: be}
103
-
if len(val) > 1 {
104
-
for _, opt := range val[1:] {
105
-
switch opt {
106
-
case "prometheus":
107
-
up.prometheus = true
108
-
case "funnel":
109
-
up.funnel = true
110
-
default:
111
-
return upstream{}, fmt.Errorf("unsupported option: %v", opt)
112
-
}
113
-
}
114
-
}
115
-
return up, nil
116
-
}
117
-
118
88
func main() {
119
89
if err := tsproxy(context.Background()); err != nil {
120
90
fmt.Fprintf(os.Stderr, "tsproxy: %v\n", err)
···
124
94
125
95
func tsproxy(ctx context.Context) error {
126
96
var (
127
-
state = flag.String("state", "", "Optional directory for storing Tailscale state.")
128
-
tslog = flag.Bool("tslog", false, "If true, log Tailscale output.")
129
-
port = flag.Int("port", 32019, "HTTP port for metrics and service discovery.")
130
-
ver = flag.Bool("version", false, "print the version and exit")
97
+
state = flag.String("state", "", "Optional directory for storing Tailscale state.")
98
+
tslog = flag.Bool("tslog", false, "If true, log Tailscale output.")
99
+
port = flag.Int("port", 32019, "HTTP port for metrics and service discovery.")
100
+
ver = flag.Bool("version", false, "print the version and exit")
101
+
upfile = flag.String("upstream", "", "path to upstreams config file")
131
102
)
132
-
var upstreams upstreamFlag
133
-
flag.Var(&upstreams, "upstream", "Repeated for each upstream. Format: name=http://backend:8000")
134
103
flag.Parse()
135
104
136
105
if *ver {
···
138
107
os.Exit(0)
139
108
}
140
109
141
-
if len(upstreams) == 0 {
110
+
if *upfile == "" {
142
111
return fmt.Errorf("required flag missing: upstream")
143
112
}
113
+
114
+
in, err := os.ReadFile(*upfile)
115
+
if err != nil {
116
+
return err
117
+
}
118
+
inJSON, err := hujson.Standardize(in)
119
+
if err != nil {
120
+
return fmt.Errorf("hujson: %w", err)
121
+
}
122
+
var upstreams []upstream
123
+
if err := json.Unmarshal(inJSON, &upstreams); err != nil {
124
+
return fmt.Errorf("json: %w", err)
125
+
}
126
+
if len(upstreams) == 0 {
127
+
return fmt.Errorf("file does not contain any upstreams: %s", *upfile)
128
+
}
129
+
144
130
if *state == "" {
145
131
v, err := os.UserCacheDir()
146
132
if err != nil {
···
212
198
213
199
214
200
201
+
}
215
202
203
+
for i, upstream := range upstreams {
204
+
log := logger.With(slog.String("upstream", upstream.Name))
216
205
217
-
218
-
219
-
i := i
220
-
upstream := upstream
221
-
222
-
log := logger.With(slog.String("upstream", upstream.name))
223
-
224
206
ts := &tsnet.Server{
225
-
Hostname: upstream.name,
226
-
Dir: filepath.Join(*state, "tailscale-"+upstream.name),
207
+
Hostname: upstream.Name,
208
+
Dir: filepath.Join(*state, "tailscale-"+upstream.Name),
227
209
RunWebClient: true,
228
210
}
229
211
defer ts.Close()
···
242
224
243
225
lc, err := ts.LocalClient()
244
226
if err != nil {
245
-
return fmt.Errorf("tailscale: get local client for %s: %w", upstream.name, err)
227
+
return fmt.Errorf("tailscale: get local client for %s: %w", upstream.Name, err)
246
228
}
247
229
248
-
srv := &http.Server{
249
-
TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate},
250
-
Handler: promhttp.InstrumentHandlerInFlight(requestsInFlight.With(prometheus.Labels{"upstream": upstream.name}),
251
-
promhttp.InstrumentHandlerDuration(duration.MustCurryWith(prometheus.Labels{"upstream": upstream.name}),
252
-
promhttp.InstrumentHandlerCounter(requests.MustCurryWith(prometheus.Labels{"upstream": upstream.name}),
253
-
newReverseProxy(log, lc, upstream.backend)))),
230
+
backendURL, err := url.Parse(upstream.Backend)
231
+
if err != nil {
232
+
return fmt.Errorf("upstream %s: parse backend URL: %w", upstream.Name, err)
254
233
}
234
+
// TODO(sr) Instrument proxy.Transport
235
+
proxy := &httputil.ReverseProxy{
236
+
Rewrite: func(req *httputil.ProxyRequest) {
237
+
req.SetURL(backendURL)
238
+
req.SetXForwarded()
239
+
req.Out.Host = req.In.Host
240
+
},
241
+
}
242
+
proxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) {
243
+
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
244
+
logger.Error("upstream error", lerr(err))
245
+
}
255
246
256
-
g.Add(func() error {
257
-
st, err := ts.Up(ctx)
258
-
if err != nil {
259
-
return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err)
247
+
instrument := func(h http.Handler) http.Handler {
248
+
return promhttp.InstrumentHandlerInFlight(
249
+
requestsInFlight.With(prometheus.Labels{"upstream": upstream.Name}),
250
+
promhttp.InstrumentHandlerDuration(
251
+
duration.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}),
252
+
promhttp.InstrumentHandlerCounter(
253
+
requests.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}),
254
+
h,
255
+
),
256
+
),
257
+
)
258
+
}
259
+
260
+
{
261
+
var srv *http.Server
262
+
g.Add(func() error {
263
+
st, err := ts.Up(ctx)
264
+
if err != nil {
265
+
return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
266
+
}
267
+
268
+
srv = &http.Server{Handler: instrument(localTailnetHandler(log, lc, proxy))}
269
+
ln, err := ts.Listen("tcp", ":80")
270
+
if err != nil {
271
+
return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.Name, err)
272
+
}
273
+
274
+
// register in service discovery when we're ready.
275
+
targets[i] = target{name: upstream.Name, prometheus: upstream.Prometheus, magicDNS: st.Self.DNSName}
276
+
277
+
return srv.Serve(ln)
278
+
}, func(_ error) {
279
+
if srv != nil {
280
+
if err := srv.Close(); err != nil {
281
+
log.Error("server shutdown", lerr(err))
282
+
}
283
+
}
284
+
cancel()
285
+
})
286
+
}
287
+
{
288
+
var srv *http.Server
289
+
g.Add(func() error {
290
+
_, err := ts.Up(ctx)
291
+
if err != nil {
292
+
return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
293
+
}
294
+
srv = &http.Server{
295
+
TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate},
296
+
Handler: instrument(localTailnetTLSHandler(log, lc, proxy)),
297
+
}
298
+
299
+
ln, err := ts.Listen("tcp", ":443")
300
+
if err != nil {
301
+
return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.Name, err)
302
+
}
303
+
return srv.ServeTLS(ln, "", "")
304
+
}, func(_ error) {
305
+
if srv != nil {
306
+
if err := srv.Close(); err != nil {
307
+
log.Error("TLS server shutdown", lerr(err))
308
+
}
309
+
}
310
+
cancel()
311
+
})
312
+
}
313
+
if funnel := upstream.Funnel; funnel != nil {
314
+
if !funnel.Insecure && funnel.Issuer == "" {
315
+
return fmt.Errorf("upstream %s: funnel must set issuer or insecure", upstream.Name)
260
316
}
317
+
{
318
+
var srv *http.Server
319
+
g.Add(func() error {
320
+
_, err := ts.Up(ctx)
321
+
if err != nil {
322
+
return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
323
+
}
261
324
262
-
// register in service discovery when we're ready.
263
-
targets[i] = target{name: upstream.name, prometheus: upstream.prometheus, magicDNS: st.Self.DNSName}
325
+
srv = &http.Server{}
326
+
if funnel.Issuer != "" {
327
+
handler, err := oidcFunnelHandler(ctx, log, lc, funnel, proxy)
328
+
if err != nil {
329
+
return fmt.Errorf("oidc: %w", err)
330
+
}
331
+
srv.Handler = handler
332
+
} else if funnel.Insecure {
333
+
srv.Handler = insecureFunnelHandler(log, lc, proxy)
334
+
} else {
335
+
panic("funnel misconfigured")
336
+
}
337
+
srv.Handler = instrument(srv.Handler)
264
338
265
-
ln, err := ts.Listen("tcp", ":80")
266
-
if err != nil {
267
-
return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.name, err)
339
+
ln, err := ts.ListenFunnel("tcp", ":443", tsnet.FunnelOnly())
340
+
if err != nil {
341
+
return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.Name, err)
342
+
}
343
+
return srv.Serve(ln)
344
+
}, func(_ error) {
345
+
if srv != nil {
346
+
if err := srv.Close(); err != nil {
347
+
log.Error("TLS server shutdown", lerr(err))
348
+
}
349
+
}
350
+
cancel()
351
+
})
268
352
}
269
-
return srv.Serve(ln)
270
-
}, func(_ error) {
353
+
}
354
+
}
271
355
356
+
return g.Run()
357
+
}
272
358
359
+
// localTailnetHandler serves plain-HTTP on the local tailnet.
360
+
func localTailnetHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
361
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
362
+
whois, err := tsWhoIs(lc, r)
363
+
if err != nil {
364
+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
365
+
logger.ErrorContext(r.Context(), "tailscale whois", lerr(err))
366
+
return
367
+
}
273
368
369
+
// Proxy requests from tagged nodes as is.
370
+
if whois.Node.IsTagged() {
371
+
next.ServeHTTP(w, r)
372
+
return
373
+
}
274
374
375
+
req := r.Clone(r.Context())
376
+
req.Header.Set("X-Webauth-User", whois.UserProfile.LoginName)
377
+
req.Header.Set("X-Webauth-Name", whois.UserProfile.DisplayName)
378
+
next.ServeHTTP(w, req)
379
+
})
380
+
}
275
381
276
-
g.Add(func() error {
277
-
_, err := ts.Up(ctx)
382
+
// localTailnetTLSHandler serves HTTPS on the local tailnet.
383
+
func localTailnetTLSHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
384
+
var (
385
+
handler = localTailnetHandler(logger, lc, next)
386
+
dnsName string
387
+
)
388
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
389
+
if r.TLS == nil {
390
+
panic("tailnet handler wants tls")
391
+
}
392
+
393
+
if dnsName == "" {
394
+
st, err := lc.StatusWithoutPeers(r.Context())
278
395
if err != nil {
279
-
return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err)
396
+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
397
+
logger.ErrorContext(r.Context(), "tailscale status", slog.Any("err", err))
398
+
return
280
399
}
400
+
dnsName = strings.TrimSuffix(st.Self.DNSName, ".")
401
+
}
281
402
282
-
if upstream.funnel {
283
-
ln, err := ts.ListenFunnel("tcp", ":443")
284
-
if err != nil {
285
-
return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.name, err)
286
-
}
287
-
return srv.Serve(ln)
288
-
}
403
+
if strings.TrimSuffix(r.Host, ".") != dnsName {
404
+
http.Redirect(w, r, fmt.Sprintf("https://%s%s", dnsName, r.RequestURI), http.StatusPermanentRedirect)
405
+
return
406
+
}
289
407
290
-
ln, err := ts.Listen("tcp", ":443")
291
-
if err != nil {
292
-
return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.name, err)
293
-
}
294
-
return srv.ServeTLS(ln, "", "")
295
-
}, func(_ error) {
408
+
handler.ServeHTTP(w, r)
409
+
})
410
+
}
411
+
412
+
// insecureFunnelHandler handles HTTPS requests coming from Tailscale Funnel nodes.
413
+
// This is marked insecure because the upstream is exposed to the public Internet.
414
+
// The upstream is responsible for implementing authentication.
415
+
func insecureFunnelHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
416
+
return localTailnetTLSHandler(logger, lc, next)
417
+
}
418
+
419
+
// oidcFunnelHandlers serves Funnel requests, requiring authentication via the configured OIDC issuer.
420
+
func oidcFunnelHandler(ctx context.Context, logger *slog.Logger, lc tailscaleLocalClient, cfg *funnelConfig, next http.Handler) (http.Handler, error) {
421
+
st, err := lc.StatusWithoutPeers(ctx)
422
+
if err != nil {
423
+
return nil, fmt.Errorf("tailscale status: %w", err)
424
+
}
425
+
426
+
redir := &url.URL{Scheme: "https", Path: ".oidc-callback"}
427
+
redir.Host = strings.TrimSuffix(st.Self.DNSName, ".")
428
+
429
+
wrapper, err := middleware.NewFromDiscovery(ctx, nil, cfg.Issuer, cfg.ClientID, cfg.ClientSecret, redir.String())
430
+
if err != nil {
431
+
return nil, fmt.Errorf("oidc middleware: %w", err)
432
+
}
433
+
wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile)
434
+
435
+
return wrapper.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
436
+
if r.TLS == nil {
437
+
panic("oidc handler wants tls")
438
+
}
439
+
440
+
_, err := tsWhoIs(lc, r)
441
+
if err != nil {
442
+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
443
+
logger.ErrorContext(r.Context(), "tailscale whois", lerr(err))
444
+
return
445
+
}
446
+
447
+
tok := middleware.IDJWTFromContext(r.Context())
448
+
if tok == nil {
449
+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
450
+
logger.ErrorContext(r.Context(), "jwt token missing")
451
+
return
452
+
}
453
+
email, err := tok.StringClaim("email")
454
+
if err != nil {
455
+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
456
+
logger.ErrorContext(r.Context(), "claim missing", slog.String("claim", "email"))
457
+
return
458
+
}
459
+
name, err := tok.StringClaim("name")
460
+
if err != nil {
461
+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
462
+
logger.ErrorContext(r.Context(), "claim missing", slog.String("claim", "name"))
463
+
return
464
+
}
465
+
466
+
req := r.Clone(r.Context())
467
+
req.Header.Set("X-Webauth-User", email)
468
+
req.Header.Set("X-Webauth-Name", name)
469
+
470
+
next.ServeHTTP(w, r)
471
+
})), nil
472
+
}
473
+
474
+
type tailscaleLocalClient interface {
475
+
WhoIs(context.Context, string) (*apitype.WhoIsResponse, error)
476
+
StatusWithoutPeers(context.Context) (*ipnstate.Status, error)
477
+
}
478
+
479
+
func tsWhoIs(lc tailscaleLocalClient, r *http.Request) (*apitype.WhoIsResponse, error) {
480
+
whois, err := lc.WhoIs(r.Context(), r.RemoteAddr)
481
+
if err != nil {
482
+
return nil, fmt.Errorf("tailscale whois: %w", err)
483
+
}
484
+
485
+
if whois.Node == nil {
486
+
return nil, errors.New("tailscale whois: node missing")
487
+
}
488
+
489
+
if whois.UserProfile == nil {
490
+
return nil, errors.New("tailscale whois: user profile missing")
491
+
}
492
+
return whois, nil
493
+
}
494
+
495
+
func serveDiscovery(self string, targets []target) http.Handler {
496
+
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
497
+
var tgs []string
+6
go.sum
+6
go.sum
···
117
117
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
118
118
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
119
119
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
120
+
github.com/lstoll/oidc v1.0.0-beta.4.0.20250106123456-6ffce62670fe h1:QBlUtM+Rv9P+W3k9C6+xLgpssfxcKd8Ir+pvNM7E23Y=
121
+
github.com/lstoll/oidc v1.0.0-beta.4.0.20250106123456-6ffce62670fe/go.mod h1:H1Y2Ektfl9aWzSHYT1qf6lXpE9mdil6ZavkI/5+N5Qg=
120
122
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
121
123
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
122
124
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg=
···
183
185
github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg=
184
186
github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA=
185
187
github.com/tc-hib/winres v0.2.1/go.mod h1:C/JaNhH3KBvhNKVbvdlDWkbMDO9H4fKKDaN7/07SSuk=
188
+
github.com/tink-crypto/tink-go/v2 v2.2.0 h1:L2Da0F2Udh2agtKztdr69mV/KpnY3/lGTkMgLTVIXlA=
189
+
github.com/tink-crypto/tink-go/v2 v2.2.0/go.mod h1:JJ6PomeNPF3cJpfWC0lgyTES6zpJILkAX0cJNwlS3xU=
186
190
github.com/u-root/u-root v0.12.0 h1:K0AuBFriwr0w/PGS3HawiAw89e3+MU7ks80GpghAsNs=
187
191
github.com/u-root/u-root v0.12.0/go.mod h1:FYjTOh4IkIZHhjsd17lb8nYW6udgXdJhG1c0r6u0arI=
188
192
github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM=
···
208
212
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
209
213
golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA=
210
214
golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I=
215
+
golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE=
216
+
golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
211
217
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
212
218
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
213
219
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+222
-102
tsproxy.go
+222
-102
tsproxy.go
···
1
1
2
2
3
+
import (
4
+
"context"
5
+
"crypto/subtle"
6
+
"crypto/tls"
7
+
"encoding/json"
8
+
"errors"
3
9
4
10
5
11
12
+
"net"
13
+
"net/http"
14
+
"net/http/httputil"
15
+
"net/netip"
16
+
"net/url"
17
+
"os"
18
+
"path/filepath"
6
19
7
20
8
21
···
16
29
17
30
18
31
19
-
20
-
21
-
22
-
23
-
24
-
25
-
26
-
27
-
28
-
29
-
30
32
"github.com/tailscale/hujson"
31
33
"tailscale.com/client/local"
32
34
"tailscale.com/client/tailscale/apitype"
33
-
"tailscale.com/ipn/ipnstate"
35
+
"tailscale.com/ipn"
34
36
"tailscale.com/tsnet"
35
37
tslogger "tailscale.com/types/logger"
36
38
)
37
39
40
+
// ctxConn is a key to look up a net.Conn stored in an HTTP request's context.
41
+
type ctxConn struct{}
38
42
43
+
var (
44
+
requestsInFlight = promauto.NewGaugeVec(
45
+
prometheus.GaugeOpts{
39
46
40
47
41
48
···
64
71
65
72
66
73
67
-
68
-
69
74
Name string
70
75
Backend string
71
76
Prometheus bool
72
-
Funnel *funnelConfig `json:"funnel,omitempty"`
77
+
Funnel *funnelConfig
73
78
}
74
79
75
80
type funnelConfig struct {
76
81
82
+
Issuer string
83
+
ClientID string
84
+
ClientSecret string
85
+
User string
86
+
Password string
87
+
IP []string
88
+
}
77
89
90
+
type target struct {
78
91
79
92
80
93
···
233
246
234
247
235
248
249
+
}
250
+
proxy.ErrorHandler = func(w http.ResponseWriter, _ *http.Request, err error) {
251
+
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
252
+
log.Error("upstream error", lerr(err))
253
+
}
236
254
255
+
instrument := func(h http.Handler) http.Handler {
237
256
238
257
239
258
···
251
270
252
271
253
272
254
-
255
-
256
-
257
-
258
-
259
-
260
-
261
-
262
-
263
-
264
-
265
273
return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
266
274
}
267
275
268
-
srv = &http.Server{Handler: instrument(localTailnetHandler(log, lc, proxy))}
276
+
srv = &http.Server{Handler: instrument(redirect(st.Self.DNSName, false, tailnet(log, lc, proxy)))}
269
277
ln, err := ts.Listen("tcp", ":80")
270
278
if err != nil {
271
279
return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.Name, err)
···
287
295
{
288
296
var srv *http.Server
289
297
g.Add(func() error {
290
-
_, err := ts.Up(ctx)
298
+
st, err := ts.Up(ctx)
291
299
if err != nil {
292
300
return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
293
301
}
302
+
294
303
srv = &http.Server{
295
304
TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate},
296
-
Handler: instrument(localTailnetTLSHandler(log, lc, proxy)),
305
+
Handler: instrument(redirect(st.Self.DNSName, true, tailnet(log, lc, proxy))),
297
306
}
298
307
299
308
ln, err := ts.Listen("tcp", ":443")
···
304
313
}, func(_ error) {
305
314
if srv != nil {
306
315
if err := srv.Close(); err != nil {
307
-
log.Error("TLS server shutdown", lerr(err))
316
+
log.Error("server shutdown", lerr(err))
308
317
}
309
318
}
310
319
cancel()
311
320
})
312
321
}
313
322
if funnel := upstream.Funnel; funnel != nil {
314
-
if !funnel.Insecure && funnel.Issuer == "" {
315
-
return fmt.Errorf("upstream %s: funnel must set issuer or insecure", upstream.Name)
316
-
}
317
323
{
318
324
var srv *http.Server
319
325
g.Add(func() error {
320
-
_, err := ts.Up(ctx)
326
+
st, err := ts.Up(ctx)
321
327
if err != nil {
322
328
return fmt.Errorf("tailscale: wait for tsnet %s to be ready: %w", upstream.Name, err)
323
329
}
324
330
325
-
srv = &http.Server{}
326
-
if funnel.Issuer != "" {
327
-
handler, err := oidcFunnelHandler(ctx, log, lc, funnel, proxy)
331
+
var handler http.Handler
332
+
switch {
333
+
case funnel.Insecure:
334
+
handler = insecureFunnel(log, lc, proxy)
335
+
case funnel.Issuer != "":
336
+
redir := &url.URL{Scheme: "https", Host: strings.TrimSuffix(st.Self.DNSName, "."), Path: ".oidc-callback"}
337
+
wrapper, err := middleware.NewFromDiscovery(ctx, nil, funnel.Issuer, funnel.ClientID, funnel.ClientSecret, redir.String())
328
338
if err != nil {
329
-
return fmt.Errorf("oidc: %w", err)
339
+
return fmt.Errorf("oidc middleware for %s: %w", upstream.Name, err)
330
340
}
331
-
srv.Handler = handler
332
-
} else if funnel.Insecure {
333
-
srv.Handler = insecureFunnelHandler(log, lc, proxy)
334
-
} else {
335
-
panic("funnel misconfigured")
341
+
wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile)
342
+
343
+
handler = wrapper.Wrap(oidcFunnel(log, lc, proxy))
344
+
case funnel.User != "":
345
+
handler = insecureFunnel(log, lc, basicAuth(log, funnel.User, funnel.Password, proxy))
346
+
default:
347
+
return fmt.Errorf("upstream %s must set funnel.insecure or funnel.issuer", upstream.Name)
336
348
}
337
-
srv.Handler = instrument(srv.Handler)
338
349
350
+
handler = redirect(st.Self.DNSName, true, handler)
351
+
352
+
if len(funnel.IP) > 0 {
353
+
var allow []netip.Prefix
354
+
for _, ip := range funnel.IP {
355
+
allow = append(allow, netip.MustParsePrefix(ip))
356
+
}
357
+
handler = restrictNetworks(log, allow, handler)
358
+
}
359
+
360
+
srv = &http.Server{
361
+
Handler: instrument(handler),
362
+
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
363
+
return context.WithValue(ctx, ctxConn{}, c)
364
+
},
365
+
}
366
+
339
367
ln, err := ts.ListenFunnel("tcp", ":443", tsnet.FunnelOnly())
340
368
if err != nil {
341
369
···
344
372
}, func(_ error) {
345
373
if srv != nil {
346
374
if err := srv.Close(); err != nil {
347
-
log.Error("TLS server shutdown", lerr(err))
375
+
log.Error("server shutdown", lerr(err))
348
376
}
349
377
}
350
378
cancel()
···
356
384
return g.Run()
357
385
}
358
386
359
-
// localTailnetHandler serves plain-HTTP on the local tailnet.
360
-
func localTailnetHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
387
+
func redirect(fqdn string, forceSSL bool, next http.Handler) http.Handler {
388
+
if fqdn == "" {
389
+
panic("redirect: fqdn cannot be empty")
390
+
}
391
+
fqdn = strings.TrimSuffix(fqdn, ".")
361
392
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
393
+
if forceSSL && r.TLS == nil {
394
+
http.Redirect(w, r, fmt.Sprintf("https://%s%s", fqdn, r.RequestURI), http.StatusPermanentRedirect)
395
+
return
396
+
}
397
+
398
+
if r.TLS != nil && strings.TrimSuffix(r.Host, ".") != fqdn {
399
+
http.Redirect(w, r, fmt.Sprintf("https://%s%s", fqdn, r.RequestURI), http.StatusPermanentRedirect)
400
+
return
401
+
}
402
+
next.ServeHTTP(w, r)
403
+
})
404
+
}
405
+
406
+
func basicAuth(logger *slog.Logger, user, password string, next http.Handler) http.Handler {
407
+
if user == "" || password == "" {
408
+
panic("user and password are required")
409
+
}
410
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
411
+
u, p, ok := r.BasicAuth()
412
+
if ok {
413
+
userCheck := subtle.ConstantTimeCompare([]byte(user), []byte(u))
414
+
passwordCheck := subtle.ConstantTimeCompare([]byte(password), []byte(p))
415
+
if userCheck == 1 && passwordCheck == 1 {
416
+
next.ServeHTTP(w, r)
417
+
return
418
+
}
419
+
}
420
+
logger.ErrorContext(r.Context(), "authentication failed", slog.String("user", u))
421
+
w.Header().Set("WWW-Authenticate", "Basic realm=\"protected\", charset=\"UTF-8\"")
422
+
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
423
+
})
424
+
}
425
+
426
+
func tailnet(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
427
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
362
428
whois, err := tsWhoIs(lc, r)
363
429
if err != nil {
364
430
···
379
445
})
380
446
}
381
447
382
-
// localTailnetTLSHandler serves HTTPS on the local tailnet.
383
-
func localTailnetTLSHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
384
-
var (
385
-
handler = localTailnetHandler(logger, lc, next)
386
-
dnsName string
387
-
)
448
+
func insecureFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
388
449
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
389
-
if r.TLS == nil {
390
-
panic("tailnet handler wants tls")
450
+
whois, err := tsWhoIs(lc, r)
451
+
if err != nil {
452
+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
453
+
logger.ErrorContext(r.Context(), "tailscale whois", lerr(err))
454
+
return
391
455
}
392
-
393
-
if dnsName == "" {
394
-
st, err := lc.StatusWithoutPeers(r.Context())
395
-
if err != nil {
396
-
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
397
-
logger.ErrorContext(r.Context(), "tailscale status", slog.Any("err", err))
398
-
return
399
-
}
400
-
dnsName = strings.TrimSuffix(st.Self.DNSName, ".")
401
-
}
402
-
403
-
if strings.TrimSuffix(r.Host, ".") != dnsName {
404
-
http.Redirect(w, r, fmt.Sprintf("https://%s%s", dnsName, r.RequestURI), http.StatusPermanentRedirect)
456
+
if !whois.Node.IsTagged() {
457
+
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
458
+
logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node")
405
459
return
406
460
}
407
461
408
-
handler.ServeHTTP(w, r)
462
+
next.ServeHTTP(w, r)
409
463
})
410
464
}
411
465
412
-
// insecureFunnelHandler handles HTTPS requests coming from Tailscale Funnel nodes.
413
-
// This is marked insecure because the upstream is exposed to the public Internet.
414
-
// The upstream is responsible for implementing authentication.
415
-
func insecureFunnelHandler(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
416
-
return localTailnetTLSHandler(logger, lc, next)
417
-
}
418
-
419
-
// oidcFunnelHandlers serves Funnel requests, requiring authentication via the configured OIDC issuer.
420
-
func oidcFunnelHandler(ctx context.Context, logger *slog.Logger, lc tailscaleLocalClient, cfg *funnelConfig, next http.Handler) (http.Handler, error) {
421
-
st, err := lc.StatusWithoutPeers(ctx)
422
-
if err != nil {
423
-
return nil, fmt.Errorf("tailscale status: %w", err)
424
-
}
425
-
426
-
redir := &url.URL{Scheme: "https", Path: ".oidc-callback"}
427
-
redir.Host = strings.TrimSuffix(st.Self.DNSName, ".")
428
-
429
-
wrapper, err := middleware.NewFromDiscovery(ctx, nil, cfg.Issuer, cfg.ClientID, cfg.ClientSecret, redir.String())
430
-
if err != nil {
431
-
return nil, fmt.Errorf("oidc middleware: %w", err)
432
-
}
433
-
wrapper.OAuth2Config.Scopes = append(wrapper.OAuth2Config.Scopes, oidc.ScopeProfile)
434
-
435
-
return wrapper.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
436
-
if r.TLS == nil {
437
-
panic("oidc handler wants tls")
438
-
}
439
-
440
-
_, err := tsWhoIs(lc, r)
466
+
func oidcFunnel(logger *slog.Logger, lc tailscaleLocalClient, next http.Handler) http.Handler {
467
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
468
+
whois, err := tsWhoIs(lc, r)
441
469
if err != nil {
442
470
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
443
471
logger.ErrorContext(r.Context(), "tailscale whois", lerr(err))
444
472
return
445
473
}
474
+
if !whois.Node.IsTagged() {
475
+
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
476
+
logger.ErrorContext(r.Context(), "funnel handler got request from non-tagged node")
477
+
return
478
+
}
446
479
447
480
tok := middleware.IDJWTFromContext(r.Context())
448
481
if tok == nil {
449
-
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
482
+
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
450
483
logger.ErrorContext(r.Context(), "jwt token missing")
451
484
return
452
485
}
···
467
500
req.Header.Set("X-Webauth-User", email)
468
501
req.Header.Set("X-Webauth-Name", name)
469
502
470
-
next.ServeHTTP(w, r)
471
-
})), nil
503
+
next.ServeHTTP(w, req)
504
+
})
472
505
}
473
506
507
+
// restrictNetworks will only allow clients from the provided IP networks to
508
+
// access the given handler. If skip prefixes are set, paths that match any
509
+
// of the regular expressions will not have restrictions applied.
510
+
func restrictNetworks(logger *slog.Logger, allowedNetworks []netip.Prefix, next http.Handler) http.Handler {
511
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
512
+
// If the funneled connection is from tsnet, then the net.Conn will be of
513
+
// type ipn.FunnelConn.
514
+
netConn := r.Context().Value(ctxConn{})
515
+
// if the conn is wrapped inside TLS, unwrap it
516
+
if tlsConn, ok := netConn.(*tls.Conn); ok {
517
+
netConn = tlsConn.NetConn()
518
+
}
519
+
var remote netip.AddrPort
520
+
if fconn, ok := netConn.(*ipn.FunnelConn); ok {
521
+
remote = fconn.Src
522
+
} else if v, err := netip.ParseAddrPort(r.RemoteAddr); err == nil {
523
+
remote = v
524
+
} else {
525
+
logger.Error("restrictNetworks: cannot parse client IP:port", lerr(err), slog.String("remote", r.RemoteAddr))
526
+
w.WriteHeader(http.StatusUnauthorized)
527
+
return
528
+
}
529
+
530
+
for _, wl := range allowedNetworks {
531
+
if wl.Contains(remote.Addr()) {
532
+
next.ServeHTTP(w, r)
533
+
return
534
+
}
535
+
}
536
+
537
+
w.WriteHeader(http.StatusForbidden)
538
+
_, _ = fmt.Fprint(w, badNetwork)
539
+
})
540
+
}
541
+
542
+
const badNetwork = `
543
+
<html>
544
+
<head><title>Untrusted network</title></head>
545
+
<body><h1>Access from untrusted networks not permitted</h1></body>
546
+
</html>
547
+
`
548
+
474
549
type tailscaleLocalClient interface {
475
550
WhoIs(context.Context, string) (*apitype.WhoIsResponse, error)
476
-
StatusWithoutPeers(context.Context) (*ipnstate.Status, error)
477
551
}
478
552
479
553
func tsWhoIs(lc tailscaleLocalClient, r *http.Request) (*apitype.WhoIsResponse, error) {
554
+
555
+
556
+
557
+
558
+
559
+
560
+
561
+
562
+
563
+
564
+
565
+
566
+
567
+
568
+
569
+
570
+
571
+
572
+
573
+
574
+
575
+
576
+
577
+
578
+
579
+
580
+
581
+
582
+
583
+
584
+
585
+
586
+
587
+
588
+
589
+
590
+
591
+
592
+
593
+
594
+
595
+
}
596
+
597
+
func lerr(err error) slog.Attr {
598
+
return slog.Any("err", err)
599
+
}
+71
tsproxy_test.go
+71
tsproxy_test.go
···
310
310
}
311
311
}
312
312
313
+
func TestBasicAuthHandler(t *testing.T) {
314
+
t.Parallel()
315
+
316
+
logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
317
+
318
+
for _, tc := range []struct {
319
+
name string
320
+
user string
321
+
password string
322
+
request func(*http.Request)
323
+
wantNext bool
324
+
wantStatus int
325
+
}{
326
+
{
327
+
name: "no basic auth provided",
328
+
user: "admin",
329
+
password: "secret",
330
+
request: func(_ *http.Request) {},
331
+
wantStatus: http.StatusUnauthorized,
332
+
},
333
+
{
334
+
name: "wrong user",
335
+
user: "admin",
336
+
password: "secret",
337
+
request: func(r *http.Request) { r.SetBasicAuth("bad", "secret") },
338
+
wantStatus: http.StatusUnauthorized,
339
+
},
340
+
{
341
+
name: "wrong password",
342
+
user: "admin",
343
+
password: "secret",
344
+
request: func(r *http.Request) { r.SetBasicAuth("admin", "bad") },
345
+
wantStatus: http.StatusUnauthorized,
346
+
},
347
+
{
348
+
name: "ok",
349
+
user: "admin",
350
+
password: "secret",
351
+
request: func(r *http.Request) { r.SetBasicAuth("admin", "secret") },
352
+
wantNext: true,
353
+
wantStatus: http.StatusOK,
354
+
},
355
+
} {
356
+
t.Run(tc.name, func(t *testing.T) {
357
+
t.Parallel()
358
+
359
+
var nextReq *http.Request
360
+
h := basicAuth(logger, tc.user, tc.password, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
361
+
nextReq = r
362
+
fmt.Fprintf(w, "OK")
363
+
}))
364
+
w := httptest.NewRecorder()
365
+
req := httptest.NewRequest("", "/", nil)
366
+
tc.request(req)
367
+
h.ServeHTTP(w, req)
368
+
resp := w.Result()
369
+
370
+
if want, got := tc.wantStatus, resp.StatusCode; want != got {
371
+
t.Errorf("want status %d, got: %d", want, got)
372
+
}
373
+
374
+
if tc.wantNext && nextReq == nil {
375
+
t.Fatalf("next handler not called")
376
+
}
377
+
if !tc.wantNext && nextReq != nil {
378
+
t.Fatalf("next handler should not have been called")
379
+
}
380
+
})
381
+
}
382
+
}
383
+
313
384
func TestServeDiscovery(t *testing.T) {
314
385
t.Parallel()
315
386
History
1 round
0 comments
sr.aux1.dev
submitted
#0
10 commits
expand
collapse
read upstreams from a config file
use go 1.22 loopvar
split server and handler
This is in preparation for adding OIDC.
redirect https to fqdn
implement oidc funnel handler
revamp http handlers implementation and tests
fix panic in lerr when err is nil
add basic auth funnel handler
fix "upstream error" log line to include the upstream
add IP allowlist support for funnel
no conflicts, ready to merge