HTTP reverse proxy for Tailscale

read upstreams from a config file

+50 -134
+2 -2
go.mod
··· 8 8 github.com/google/go-cmp v0.7.0 9 9 github.com/oklog/run v1.1.0 10 10 github.com/prometheus/client_golang v1.21.1 11 + github.com/prometheus/common v0.63.0 12 + github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a 11 13 tailscale.com v1.82.0 12 14 ) 13 15 ··· 64 66 github.com/pierrec/lz4/v4 v4.1.21 // indirect 65 67 github.com/prometheus-community/pro-bing v0.4.0 // indirect 66 68 github.com/prometheus/client_model v0.6.1 // indirect 67 - github.com/prometheus/common v0.63.0 // indirect 68 69 github.com/prometheus/procfs v0.15.1 // indirect 69 70 github.com/safchain/ethtool v0.3.0 // indirect 70 71 github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e // indirect 71 72 github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55 // indirect 72 73 github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 // indirect 73 - github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a // indirect 74 74 github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 // indirect 75 75 github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc // indirect 76 76 github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 // indirect
+48 -69
main.go
··· 4 4 "context" 5 5 "crypto/tls" 6 6 "encoding/json" 7 - "errors" 8 7 "flag" 9 8 "fmt" 10 9 "log/slog" ··· 16 15 "path/filepath" 17 16 "sort" 18 17 "strconv" 19 - "strings" 20 18 "syscall" 21 19 22 20 "github.com/oklog/run" ··· 25 23 "github.com/prometheus/client_golang/prometheus/promauto" 26 24 "github.com/prometheus/client_golang/prometheus/promhttp" 27 25 "github.com/prometheus/common/version" 26 + "github.com/tailscale/hujson" 28 27 "tailscale.com/client/local" 29 28 "tailscale.com/client/tailscale/apitype" 30 29 "tailscale.com/tsnet" ··· 61 60 ) 62 61 ) 63 62 64 - type upstreamFlag []upstream 65 - 66 - func (f *upstreamFlag) String() string { 67 - return fmt.Sprintf("%+v", *f) 68 - } 69 - 70 - func (f *upstreamFlag) Set(val string) error { 71 - up, err := parseUpstreamFlag(val) 72 - if err != nil { 73 - return err 74 - } 75 - *f = append(*f, up) 76 - return nil 77 - } 78 - 79 63 type upstream struct { 80 - name string 81 - backend *url.URL 82 - prometheus bool 83 - funnel bool 64 + Name string 65 + Backend string 66 + Prometheus bool 67 + Funnel bool 84 68 } 85 69 86 70 type target struct { ··· 89 73 prometheus bool 90 74 } 91 75 92 - func parseUpstreamFlag(fval string) (upstream, error) { 93 - k, v, ok := strings.Cut(fval, "=") 94 - if !ok { 95 - return upstream{}, errors.New("format: name=http://backend") 96 - } 97 - val := strings.Split(v, ";") 98 - be, err := url.Parse(val[0]) 99 - if err != nil { 100 - return upstream{}, err 101 - } 102 - up := upstream{name: k, backend: be} 103 - if len(val) > 1 { 104 - for _, opt := range val[1:] { 105 - switch opt { 106 - case "prometheus": 107 - up.prometheus = true 108 - case "funnel": 109 - up.funnel = true 110 - default: 111 - return upstream{}, fmt.Errorf("unsupported option: %v", opt) 112 - } 113 - } 114 - } 115 - return up, nil 116 - } 117 - 118 76 func main() { 119 77 if err := tsproxy(context.Background()); err != nil { 120 78 fmt.Fprintf(os.Stderr, "tsproxy: %v\n", err) ··· 124 82 125 83 func tsproxy(ctx context.Context) error { 126 84 var ( 127 - state = flag.String("state", "", "Optional directory for storing Tailscale state.") 128 - tslog = flag.Bool("tslog", false, "If true, log Tailscale output.") 129 - port = flag.Int("port", 32019, "HTTP port for metrics and service discovery.") 130 - ver = flag.Bool("version", false, "print the version and exit") 85 + state = flag.String("state", "", "Optional directory for storing Tailscale state.") 86 + tslog = flag.Bool("tslog", false, "If true, log Tailscale output.") 87 + port = flag.Int("port", 32019, "HTTP port for metrics and service discovery.") 88 + ver = flag.Bool("version", false, "print the version and exit") 89 + upfile = flag.String("upstream", "", "path to upstreams config file") 131 90 ) 132 - var upstreams upstreamFlag 133 - flag.Var(&upstreams, "upstream", "Repeated for each upstream. Format: name=http://backend:8000") 134 91 flag.Parse() 135 92 136 93 if *ver { ··· 138 95 os.Exit(0) 139 96 } 140 97 98 + if *upfile == "" { 99 + return fmt.Errorf("required flag missing: upstream") 100 + } 101 + 102 + in, err := os.ReadFile(*upfile) 103 + if err != nil { 104 + return err 105 + } 106 + inJSON, err := hujson.Standardize(in) 107 + if err != nil { 108 + return fmt.Errorf("hujson: %w", err) 109 + } 110 + var upstreams []upstream 111 + if err := json.Unmarshal(inJSON, &upstreams); err != nil { 112 + return fmt.Errorf("json: %w", err) 113 + } 141 114 if len(upstreams) == 0 { 142 - return fmt.Errorf("required flag missing: upstream") 115 + return fmt.Errorf("file does not contain any upstreams: %s", *upfile) 143 116 } 117 + 144 118 if *state == "" { 145 119 v, err := os.UserCacheDir() 146 120 if err != nil { ··· 219 193 i := i 220 194 upstream := upstream 221 195 222 - log := logger.With(slog.String("upstream", upstream.name)) 196 + log := logger.With(slog.String("upstream", upstream.Name)) 223 197 224 198 ts := &tsnet.Server{ 225 - Hostname: upstream.name, 226 - Dir: filepath.Join(*state, "tailscale-"+upstream.name), 199 + Hostname: upstream.Name, 200 + Dir: filepath.Join(*state, "tailscale-"+upstream.Name), 227 201 RunWebClient: true, 228 202 } 229 203 defer ts.Close() ··· 242 216 243 217 lc, err := ts.LocalClient() 244 218 if err != nil { 245 - return fmt.Errorf("tailscale: get local client for %s: %w", upstream.name, err) 219 + return fmt.Errorf("tailscale: get local client for %s: %w", upstream.Name, err) 220 + } 221 + 222 + backendURL, err := url.Parse(upstream.Backend) 223 + if err != nil { 224 + return fmt.Errorf("upstream %s: parse backend URL: %w", upstream.Name, err) 246 225 } 247 226 248 227 srv := &http.Server{ 249 228 TLSConfig: &tls.Config{GetCertificate: lc.GetCertificate}, 250 - Handler: promhttp.InstrumentHandlerInFlight(requestsInFlight.With(prometheus.Labels{"upstream": upstream.name}), 251 - promhttp.InstrumentHandlerDuration(duration.MustCurryWith(prometheus.Labels{"upstream": upstream.name}), 252 - promhttp.InstrumentHandlerCounter(requests.MustCurryWith(prometheus.Labels{"upstream": upstream.name}), 253 - newReverseProxy(log, lc, upstream.backend)))), 229 + Handler: promhttp.InstrumentHandlerInFlight(requestsInFlight.With(prometheus.Labels{"upstream": upstream.Name}), 230 + promhttp.InstrumentHandlerDuration(duration.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}), 231 + promhttp.InstrumentHandlerCounter(requests.MustCurryWith(prometheus.Labels{"upstream": upstream.Name}), 232 + newReverseProxy(log, lc, backendURL)))), 254 233 } 255 234 256 235 g.Add(func() error { 257 236 st, err := ts.Up(ctx) 258 237 if err != nil { 259 - return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err) 238 + return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.Name, err) 260 239 } 261 240 262 241 // register in service discovery when we're ready. 263 - targets[i] = target{name: upstream.name, prometheus: upstream.prometheus, magicDNS: st.Self.DNSName} 242 + targets[i] = target{name: upstream.Name, prometheus: upstream.Prometheus, magicDNS: st.Self.DNSName} 264 243 265 244 ln, err := ts.Listen("tcp", ":80") 266 245 if err != nil { 267 - return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.name, err) 246 + return fmt.Errorf("tailscale: listen for %s on port 80: %w", upstream.Name, err) 268 247 } 269 248 return srv.Serve(ln) 270 249 }, func(_ error) { ··· 276 255 g.Add(func() error { 277 256 _, err := ts.Up(ctx) 278 257 if err != nil { 279 - return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.name, err) 258 + return fmt.Errorf("tailscale: wait for node %s to be ready: %w", upstream.Name, err) 280 259 } 281 260 282 - if upstream.funnel { 261 + if upstream.Funnel { 283 262 ln, err := ts.ListenFunnel("tcp", ":443") 284 263 if err != nil { 285 - return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.name, err) 264 + return fmt.Errorf("tailscale: funnel for %s on port 443: %w", upstream.Name, err) 286 265 } 287 266 return srv.Serve(ln) 288 267 } 289 268 290 269 ln, err := ts.Listen("tcp", ":443") 291 270 if err != nil { 292 - return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.name, err) 271 + return fmt.Errorf("tailscale: listen for %s on port 443: %w", upstream.Name, err) 293 272 } 294 273 return srv.ServeTLS(ln, "", "") 295 274 }, func(_ error) {
-63
tsproxy_test.go
··· 10 10 "net/http" 11 11 "net/http/httptest" 12 12 "net/url" 13 - "reflect" 14 - "strings" 15 13 "testing" 16 14 17 15 "github.com/google/go-cmp/cmp" ··· 27 25 28 26 func (c *fakeLocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { 29 27 return c.whois(ctx, remoteAddr) 30 - } 31 - 32 - func TestParseUpstream(t *testing.T) { 33 - t.Parallel() 34 - 35 - for _, tc := range []struct { 36 - upstream string 37 - want upstream 38 - err error 39 - }{ 40 - { 41 - upstream: "test=http://example.com:-80/", 42 - want: upstream{}, 43 - err: errors.New(`parse "http://`), 44 - }, 45 - { 46 - upstream: "test=http://localhost", 47 - want: upstream{name: "test", backend: mustParseURL("http://localhost")}, 48 - }, 49 - { 50 - upstream: "test=http://localhost;prometheus", 51 - want: upstream{name: "test", backend: mustParseURL("http://localhost"), prometheus: true}, 52 - }, 53 - { 54 - upstream: "test=http://localhost;funnel;prometheus", 55 - want: upstream{name: "test", backend: mustParseURL("http://localhost"), prometheus: true, funnel: true}, 56 - }, 57 - { 58 - upstream: "test=http://localhost;foo", 59 - want: upstream{}, 60 - err: errors.New("unsupported option: foo"), 61 - }, 62 - } { 63 - tc := tc 64 - t.Run(tc.upstream, func(t *testing.T) { 65 - t.Parallel() 66 - up, err := parseUpstreamFlag(tc.upstream) 67 - if tc.err != nil { 68 - if err == nil { 69 - t.Fatalf("want err %v, got nil", tc.err) 70 - } 71 - if !strings.Contains(err.Error(), tc.err.Error()) { 72 - t.Fatalf("want err %v, got %v", tc.err, err) 73 - } 74 - } 75 - if tc.err == nil && err != nil { 76 - t.Fatalf("want no err, got %v", err) 77 - } 78 - if diff := cmp.Diff(tc.want, up, cmp.Exporter(func(_ reflect.Type) bool { return true })); diff != "" { 79 - t.Errorf("mismatch (-want +got):\n%s", diff) 80 - } 81 - }) 82 - } 83 - } 84 - 85 - func mustParseURL(s string) *url.URL { 86 - v, err := url.Parse(s) 87 - if err != nil { 88 - panic(err) 89 - } 90 - return v 91 28 } 92 29 93 30 func TestReverseProxy(t *testing.T) {