Ugh, I’m very confused. I’m trying to solve this part, but I can’t find a solution.
All the helper functions work correctly and the specific test passed, but the `get_top_covariances` function is wrong. Below are the printed intermediate values (name, shape, contents):
foo- torch.Size([60, 128, 4]) tensor([[[-3.4697e-01, -6.9493e+00, -3.6928e+00, -6.8480e-01],
[ 3.4767e+00, 5.7545e+00, -2.0498e-01, -7.8584e-01],
[-1.1608e+00, 2.1348e+00, 5.7659e+00, 3.6523e+00],
…,
[ 6.1659e-01, -3.0746e+00, -5.8732e+00, -8.3355e+00],
[-4.1844e+00, 1.5068e+00, 1.7872e+00, 1.9598e+00],
[-1.9612e+00, -1.8472e+00, -2.5687e+00, -3.9748e-01]],
…,
[ 5.8083e+00, 2.9954e-01, -2.0871e+00, -2.6318e+00],
[-3.9213e+00, 5.0618e+00, 2.6743e-01, 4.0105e+00],
[-1.3784e+00, -1.3492e+00, -6.8733e+00, 5.4454e+00]]])
class_chg- torch.Size([60, 128]) tensor([[-6.8480e-01, -7.8584e-01, 3.6523e+00, …, -8.3355e+00,
1.9598e+00, -3.9748e-01],
[ 2.8420e-01, 5.2391e+00, 2.2153e+00, …, 6.3707e+00,
1.1093e+00, 2.4502e+00],
[ 4.8043e-01, 2.1717e+00, 2.7532e-01, …, 7.1343e-01,
-6.4532e+00, -4.3415e+00],
…,
[ 8.8360e+00, 4.5181e+00, -6.0043e-01, …, 9.0040e-01,
-1.3661e+00, 2.5932e+00],
[-1.9945e+00, -2.6196e+00, -4.0520e-03, …, 1.3709e+00,
1.7533e+00, -5.4681e-01],
[-1.3839e+00, -1.8083e+00, 5.7028e+00, …, -2.6318e+00,
4.0105e+00, 5.4454e+00]])
cov_matrix- (128, 128) [[10.31490342 -0.13934841 0.66691834 … 1.45728203 -2.49601633
-1.4768491 ]
[-0.13934841 12.93335031 -1.24922513 … 0.65039519 -0.61105151
-0.54250509]
[ 0.66691834 -1.24922513 8.0322546 … -0.66695757 0.06201199
-2.56580555]
…
[ 1.45728203 0.65039519 -0.66695757 … 11.35423392 0.92914036
-1.6231485 ]
[-2.49601633 -0.61105151 0.06201199 … 0.92914036 12.8996609
-0.18807426]
[-1.4768491 -0.54250509 -2.56580555 … -1.6231485 -0.18807426
12.94835014]]
get_top_magn- (3, 128) [[118 101 19 6 95 108 87 102 80 31 105 126 74 53 94 110 115 24
30 50 13 58 51 121 63 9 12 107 117 120 85 123 48 119 1 42
66 39 103 82 15 109 96 44 34 76 14 71 27 23 124 64 98 92
28 62 4 84 11 52 79 73 36 3 83 25 32 104 22 47 10 86
67 17 81 18 68 8 90 0 112 49 113 125 5 99 29 56 65 61
59 20 122 100 38 88 7 75 16 54 37 26 35 114 40 77 43 93
111 55 57 97 2 46 91 72 21 45 70 89 33 116 78 41 106 60
69 127]
[ 61 101 5 98 2 116 109 118 10 3 67 58 40 45 21 92 127 53
64 85 33 27 77 51 26 47 89 44 105 18 41 114 34 124 1 13
25 52 71 115 28 119 48 9 73 63 120 74 49 87 88 16 59 43
36 125 4 95 31 38 113 54 7 24 96 60 83 99 62 65 102 42
23 68 22 103 11 81 55 97 100 104 106 117 76 17 46 66 108 57
112 90 8 121 20 56 82 91 39 79 6 107 111 14 29 50 94 15
123 12 93 69 84 110 70 78 0 19 35 75 32 37 72 80 122 30
86 126]
[100 81 26 55 109 89 102 4 54 11 112 111 80 32 110 44 61 21
121 96 23 70 105 20 119 7 31 108 65 68 113 57 47 10 15 1
2 38 82 60 49 18 48 122 75 58 93 19 88 115 35 6 126 87
123 73 117 5 118 37 120 56 77 95 46 67 101 85 72 78 76 39
52 45 30 107 9 62 17 42 124 41 12 51 90 0 27 84 28 71
16 106 104 33 127 22 14 98 64 25 91 13 50 94 86 97 29 79
99 43 66 36 53 74 59 63 34 24 69 116 114 83 92 40 8 103
3 125]]