neural network learning rate fractal, following https://arxiv.org/pdf/2402.06184
nn_fractal.py
# Based on the paper by Jascha Sohl-Dickstein,
# "The boundary of neural network trainability is fractal":
# https://arxiv.org/abs/2402.06184

# Also a blog post: https://sohl-dickstein.github.io/2024/02/12/fractal.html

import math
import os

import torch

from torch.func import vmap, grad

from pyt.lib.spaces import map_space, grid
from pyt.lib.util import msave
torch.manual_seed(39525)

t_real = torch.float64
dev = torch.device("cuda:0")

# output directory for rendered images (an arbitrary choice of path)
run_dir = "nn_fractal_out"
os.makedirs(run_dir, exist_ok=True)

nonlin = "tanh"

network_n = 16 # original paper: 16

training_steps = 100 # original paper: 500-1000

dataset_size = network_n * (network_n + 1)
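# dataset_size = n*(n+1) = n**2 + n, which is exactly the number of trainable
# parameters: W_0 has network_n**2 entries and W_1 has network_n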

# alphas: mean-field neural network parametrization; reference [9] in the
# Sohl-Dickstein paper:
# Song Mei, Andrea Montanari, and Phan-Minh Nguyen. A mean field view of the
# landscape of two-layer neural networks. Proceedings of the National Academy
# of Sciences, 115(33):E7665–E7671, 2018.

alpha_1 = 1 / network_n

match nonlin:
    case "tanh":  # TODO: other nonlinearities
        sigma = torch.tanh
        alpha_0 = math.sqrt(2 / network_n)
    case _:
        # only tanh is wired up so far; the generic mean-field scaling would
        # be alpha_0 = math.sqrt(1 / network_n), but sigma is undefined here
        raise NotImplementedError(f"nonlinearity {nonlin!r} not implemented")

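# forward pass under the mean-field parametrization:
#   y_pred(x) = alpha_1 * W_1 @ sigma(alpha_0 * W_0 @ x)
# the alpha scalings keep hidden activations and outputs O(1) as network_n grows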
def y_pred(x, W_0, W_1):
    return alpha_1 * W_1 @ sigma(alpha_0 * W_0 @ x)
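# shapes: x is [network_n, dataset_size], W_0 is [network_n, network_n],
# W_1 is [1, network_n], so y_pred returns [1, dataset_size]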

def calculate_loss(D, W_0, W_1):
    x, y = D
    loss = ((y - y_pred(x, W_0, W_1)) ** 2).mean()
    # returned twice: the first copy is differentiated, the second passes
    # through unchanged as the aux output of grad(..., has_aux=True)
    return (loss, loss)

run_network = grad(calculate_loss, argnums=(1, 2), has_aux=True)
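# run_network(D, W_0, W_1) evaluates to ((dL/dW_0, dL/dW_1), loss); a quick
# illustrative shape check:
#   (g0, g1), l = run_network(D, W_0, W_1)
#   assert g0.shape == W_0.shape and g1.shape == W_1.shape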

dataset_x = torch.randn([network_n, dataset_size], dtype=t_real, device=dev)
dataset_y = torch.randn((dataset_size,), dtype=t_real, device=dev)
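# a fixed random regression dataset: the fractal boundary comes from the
# training dynamics, not from structure in the data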

D = (dataset_x, dataset_y)

def train_network(eta_0, eta_1, W_0_init, W_1_init):
    W_0 = W_0_init.clone()
    W_1 = W_1_init.clone()

    loss_init = calculate_loss(D, W_0, W_1)[0].clamp(min=1e-8)

    loss_record = []

    for index in range(training_steps):
        ((grad_W_0, grad_W_1), loss) = run_network(D, W_0, W_1)
        W_0 = W_0 - grad_W_0 * eta_0
        W_1 = W_1 - grad_W_1 * eta_1
        loss_record.append(loss / loss_init)

    return torch.stack(loss_record[-20:]).nan_to_num(nan=1e6, posinf=1e6).mean()
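# the score for one (eta_0, eta_1) pair: the mean of the last 20 losses,
# normalized by the initial loss, with NaN/inf from diverged runs clamped
# to a large finite value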

train_many = vmap(train_network, in_dims=(0, 0, None, None))
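# vmap maps train_network over the leading axis of the two eta arguments and
# shares the same initial weights across the batch (in_dims=None); illustrative
# usage with a hypothetical flat batch of learning rates:
#   e = torch.logspace(-3, 1, 8, dtype=t_real, device=dev)
#   scores = train_many(e, e, W_0_init, W_1_init)  # -> shape [8]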

span = 250, 250
origin = (span[0] // 2) - 10, (span[1] // 2) - 10

stretch = 1, 1

zooms = []

scale = 4096 * 2
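# map_space and grid are project-local helpers (pyt.lib.spaces); the working
# assumption here is that they turn origin/span/zooms/stretch/scale into an
# image shape plus a per-pixel grid of coordinates in learning-rate space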

def main():
    _W_0 = torch.randn([network_n, network_n], dtype=t_real, device=dev)
    _W_1 = torch.randn([1, network_n], dtype=t_real, device=dev)

    mapping = map_space(origin, span, zooms, stretch, scale)
    (_, (height, width)) = mapping

    canvas = torch.zeros([height, width], dtype=t_real, device=dev)

    # eta = learning rate
    etas = grid(mapping).to(dev)

    eta_0 = etas[:, :, 1]
    eta_1 = etas[:, :, 0].flip(0)
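    # assumed layout of grid(): the last axis holds the two coordinates;
    # flipping axis 0 makes eta_1 increase from bottom to top in the image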

    cols_per_chunk = 1
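    # one column of the eta grid per train_many call keeps peak GPU memory
    # small; larger chunks would trade memory for throughput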

    convergence_threshold = 1.0
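    # scores are normalized by the initial loss, so below 1.0 means training
    # reduced the loss (converged) and above 1.0 means it grew (diverged)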

    last_report = 0

    def save_maps(c, tag):
        # render two images from a (partial) canvas: a convergence map,
        # linear in how far the normalized loss fell below the threshold,
        # and a divergence map, log-scaled in how far it overshot; each is
        # normalized to [0, 1] before saving
        conv = c < convergence_threshold

        t_conv = 1 - c / convergence_threshold
        t_div = torch.log1p(c - convergence_threshold) / torch.log1p(torch.tensor(1e6))

        t_conv = t_conv * conv
        t_div = t_div * ~conv

        t_conv /= t_conv.max().clamp(min=1e-6)
        t_div /= t_div.max().clamp(min=1e-6)

        msave(t_conv, f"{run_dir}/conv_{tag}")
        msave(t_div, f"{run_dir}/div_{tag}")

    for col_start in range(0, width, cols_per_chunk):
        col_end = col_start + cols_per_chunk
        e0 = eta_0[:, col_start:col_end].reshape(-1)
        e1 = eta_1[:, col_start:col_end].reshape(-1)

        res = train_many(e0, e1, _W_0, _W_1)
        res = res.nan_to_num(nan=1e6, posinf=1e6, neginf=-1e6)

        canvas[:, col_start:col_end] = res.reshape(height, cols_per_chunk)

        # save a progress snapshot every 64 columns
        if col_end > last_report + 64:
            last_report = col_end
            save_maps(canvas[:, 0:col_end].clone(), col_start)

    save_maps(canvas.clone(), "final")

if __name__ == "__main__":
    main()
152