Loss exploding to nan

#1
by tony0278611 - opened

Hello,

I tried to use these hyperparameters to replicate the ImageNet results. Before this I tested the hparams for vit_wee and the result matched almost exactly. However, when I use the same hparams for the vit_little model, training quickly stalls (around epoch 10-15) and then the loss climbs and explodes to nan (around epoch 20-25).

Is this a problem with the hyperparameters, some internal problem within the timm library, or could this even be a problem with my (CUDA) installation?

Best regards
Tony

I cloned the main branch on commit https://github.com/huggingface/pytorch-image-models/commit/8d41071da68055f128a01070a998712e637c4963

I just now see that timm uses a tagging system instead of dev branches, so I will try again on release 1.0.19. If this issue was known at that time, I would still appreciate you linking me to a GitHub issue or PR.

PyTorch Image Models org
edited 22 days ago

@tony0278611 little things can make training unstable, and when you're on the edge, different runs on different systems or different PyTorch versions can blow up or stay okay.

Could try:

  • revise LR down a bit
  • lower grad clipping to 1.0
  • use bfloat16 as the low precision dtype if you're on Ampere or later (--amp-dtype bfloat16), see the quick sketch after this list
  • change optimizer from nadamw -> adamw ... nadamw does show a bit more variability in stability
  • change your seed
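
On the bfloat16 point above: float16 tops out around 65504 while bfloat16 keeps float32's exponent range (at lower precision), so large intermediates that overflow in float16 stay finite in bfloat16. Quick illustration, nothing timm-specific, just plain PyTorch:

    import torch

    # float16 overflows past ~65504; bfloat16 trades mantissa bits for
    # float32's exponent range, so large activations / attention logits survive.
    x = torch.tensor([70000.0])
    print(x.to(torch.float16))   # inf
    print(x.to(torch.bfloat16))  # ~70144., still finite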

Thanks a lot for your advice, I will try the amp-dtype first. Am I understanding correctly that the params are chosen to be on the edge of stability and that nan losses aren't necessarily a sign of something wrong with my setup?

Changing to bfloat16 did the trick! I guess the LR was not the problem but rather the gradients going out of float16 range.

PyTorch Image Models org

Yeah, I find being close(ish) to unstable at the max LR of the schedule yields good results. But that means you may have to tweak things a bit across varying setups...

Thanks for your help. I ended up tweaking all the parameters you mentioned (lowered the LR from 8e-4 to 6e-4) and it seems to be stable. I will report the end result (training mediumd now instead of little because the cost-effectiveness seemed way better).

Is there some information on how you performed the hyperparameter search? In the standard training script, the in1k validation set is used for validation during training. Did you split off a separate validation set for the hyperparameter search to avoid overfitting on the official validation set?

Specifically, your advice to

  • change your seed

confused me a little, since I am trying to look for modifications that perform well under various seeds.

Best regards, and thank you once again.

Tony

Edit: it did not explode to nan, but it got stuck at random guessing. I have now revised both the LR and weight decay to half their original values (4e-4 and 4e-2) and grad_clip to 2.0. I hope this runs smoothly now. I am quite surprised how big the difference is between setups.

PyTorch Image Models org
edited 19 days ago

@tony0278611 strange that it's this temperamental. How many GPUs are you running this on? What PyTorch + CUDA version?

I have in the past run into PyTorch versions with issues, especially around F.sdpa attention, that produced all-around bad results (set TIMM_FUSED_ATTN=0 in the env to disable it, though that slows things down and uses more GPU RAM).
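
For reference, roughly what that toggle switches between. This is an illustrative sketch of the env-var gate rather than the exact timm code (the real check sits behind use_fused_attn() in timm.layers):

    import os
    import torch
    import torch.nn.functional as F

    # Illustrative gate: TIMM_FUSED_ATTN=0 forces the eager matmul + softmax path.
    _USE_FUSED = int(os.environ.get('TIMM_FUSED_ATTN', '1')) > 0

    def attend(q, k, v, scale):
        if _USE_FUSED:
            # fused kernel path (flash / mem-efficient SDPA chosen by PyTorch)
            return F.scaled_dot_product_attention(q, k, v)
        # eager fallback: explicit matmul + softmax, slower and more memory hungry
        attn = (q * scale) @ k.transpose(-2, -1)
        return attn.softmax(dim=-1) @ v

    q = k = v = torch.randn(1, 4, 16, 32)  # (batch, heads, tokens, head_dim)
    out = attend(q, k, v, scale=32 ** -0.5)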

EDIT: another possible failure point: torch.compile was enabled, so it's possible that's producing bad output ... passing --torchcompile '' (empty string) would override that and disable it.

PyTorch Image Models org
edited 19 days ago

FYI I tried the hparams as-is on a 2x 3090 system that wasn't running anything else... I ran it to epoch 25 and there are no signs of any issues. I was using PyTorch 2.5 as I didn't have a newer version on this system. I think I had issues with 2.6 before. I've done some runs with 2.7.1 on a variety of models (not this one specifically) and they seemed okay.

Possibly a PyTorch regression or issues with the environment / driver?

epoch,train_loss,eval_loss,eval_top1,eval_top5,lr
0,6.845011234283447,6.552849532165527,1.1179999999427794,4.371999997863769,4.0475e-05
1,6.636343002319336,6.023297029418945,3.4020000004577637,11.12400000366211,8.045e-05
2,6.472070693969727,5.687659181518555,5.532000002136231,16.584000005493163,0.00012042500000000001
3,6.3668622970581055,5.363080926971436,8.477999996948242,22.04800001098633,0.00016040000000000002
4,6.255019664764404,5.076617205963135,11.09799999633789,26.896000006103517,0.00020037500000000003
5,6.132077693939209,4.721762015075684,14.303999997558593,32.65000000488281,0.00024035000000000004
6,5.998349189758301,4.397778011779785,18.53400000732422,39.35600001464844,0.00028032500000000005
7,5.871716022491455,4.1015487825012205,22.01600001953125,44.338000004882815,0.00032030000000000003
8,5.734888076782227,3.813698821716309,25.602000009765625,49.134000014648436,0.000360275
9,5.616973400115967,3.5748003886413575,28.482000006103515,53.33199999511719,0.00040025000000000005
10,5.493976593017578,3.3574810093688963,32.04400000488281,57.49200004394531,0.00044022500000000003
11,5.3969340324401855,3.2224467460632322,34.15999999511719,59.84999999023437,0.00048020000000000007
12,5.317069053649902,3.062845891036987,36.322000021972656,62.262000002441404,0.000520175
13,5.2350969314575195,2.922411344909668,38.17800000244141,64.792,0.00056015
14,5.1743645668029785,2.8388223177337646,40.2139999987793,66.51599998046875,0.000600125
15,5.112250804901123,2.731342276535034,41.42200002197266,67.63000000488282,0.0006401
16,5.05386209487915,2.702717731513977,42.84600001464844,69.40000002441407,0.000680075
17,5.01328182220459,2.602892908782959,44.252000002441406,70.55000001953125,0.00072005
18,4.9575581550598145,2.5506915132141113,44.89600001220703,71.29999999023437,0.000760025
19,4.913570404052734,2.474668681640625,46.904000006103516,72.97199998535156,0.0007978101276734673
20,4.8612751960754395,2.4122673700714112,47.698000004882815,73.87800001708985,0.0007975858919432082
21,4.819283485412598,2.3244275868606565,49.10999999633789,75.07000000732423,0.0007973507630487585
22,4.769523620605469,2.2662763243865967,50.12200001953125,76.09600002929687,0.0007971047474362959
23,4.721229553222656,2.193927070159912,51.011999995117186,77.03999999511718,0.0007968478518504626
24,4.67425537109375,2.185399310913086,52.55199999267578,77.89200004394532,0.0007965800833341808
25,4.646318435668945,2.0754157316207884,53.21799999511719,78.8020000390625,0.0007963014492284596

Hi, thank you so much for your help!

Did you run with or without fused attention? Your call to disable fused attention was probably the correct one. I ran it again and all signs of instability disappeared. Fused attention would explain the big difference between setups, since it can be hardware- and version-dependent and also increases the risk of numerical instability.

I appended my setup and my run summaries below.

Best regards
Tony

Edit: I was wondering: why is it that fused attention (if I am not mistaken, the PyTorch feature is still in beta) is controlled via an environment variable? Are there any plans to let hyperparams override the default setting?

My setup looks like this (2 GPUs):

~/src/pytorch-image-models$ python -c "import torch; import subprocess; \
    nvcc = subprocess.run(['nvcc', '--version'], capture_output=True, text=True); \
    nvcc_ver = next((l for l in nvcc.stdout.splitlines() if 'release' in l), 'N/A'); \
    print(f'PyTorch: {torch.__version__}, CUDA (built): {torch.version.cuda}, CUDA available: {torch.cuda.is_available()}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"N/A\"}, nvcc: {nvcc_ver}')"

PyTorch: 2.7.1+cu126, CUDA (built): 12.6, CUDA available: True, GPU: NVIDIA RTX 6000 Ada Generation, nvcc: Cuda compilation tools, release 12.0, V12.0.140

Without fused attention, my summary looks like this (note this is for mediumd with reduced lr_base, clip_grad and weight_decay):

epoch,train_loss,eval_loss,eval_top1,eval_top5,lr
0,6.902774810791016,6.891392500457764,0.324,1.33,2.0474999999999996e-05
1,6.902324676513672,6.893279375457763,0.182,0.926,4.0449999999999994e-05
2,6.886255264282227,6.7165656259155275,0.58,2.3379999999809264,6.0425e-05
3,6.798336982727051,6.4984868728637695,1.4320000006103515,4.956,8.039999999999999e-05
4,6.708034038543701,6.1457418740844725,3.0559999994659424,9.692000004272462,0.000100375
5,6.831315994262695,6.6255868756103515,1.036,3.791999999847412,0.00012035
6,6.870147705078125,6.68529593963623,0.7019999999809265,2.761999999923706,0.000140325
7,6.89525842666626,6.875268125152588,0.232,0.93,0.0001603
8,6.878701210021973,6.702481873779297,0.7339999999809265,2.783999999961853,0.000180275
9,6.786198139190674,6.220958825073242,2.5019999994659425,8.072000001525879,0.00020025000000000002
10,6.627785682678223,5.774234102783203,5.409999998168946,15.631999999389649,0.000220225
11,6.46635627746582,5.3786505090332035,8.352000003051758,21.276000001831054,0.0002402
12,6.327280044555664,4.945341114501953,11.466000002441406,27.48600001953125,0.000260175
13,6.189669132232666,4.6377080366516115,14.591999998779297,33.250000004882814,0.00028015
14,6.055519104003906,4.332352998046875,18.3980000012207,39.378000004882814,0.000300125
15,5.9295501708984375,4.008611104888916,22.09600001464844,45.03800002197266,0.0003201
16,5.802860260009766,3.746308698348999,25.625999990234376,49.915999990234376,0.000340075
17,5.655991077423096,3.473963011474609,29.856,54.97399996582031,0.00036005
18,5.545177459716797,3.2533175382995605,33.120000013427735,58.786000053710936,0.000380025
19,5.433896064758301,3.05142001953125,36.594000024414065,62.51600003417969,0.00039890574859981264
20,5.305156707763672,2.8383489492797853,39.868000008544925,66.36599997070313,0.0003987937008521722
21,5.194972038269043,2.674523726425171,42.74800000854492,69.15799999267578,0.0003986762099286792
22,5.098299026489258,2.522437100753784,45.69200000366211,71.78999999511718,0.00039855327905040676
23,5.0292649269104,2.4638070092391966,46.60200002441406,72.62600000976562,0.000398424911587567
24,5.16062068939209,2.4736761631774904,46.524000014648436,72.66999999023437,0.00039829111105941866
25,4.9724273681640625,2.332351258468628,49.254000009765626,75.14799998046875,0.00039815188113417085

However, with fused attention activated, the same hyperparameters resulted in this:

epoch,train_loss,eval_loss,eval_top1,eval_top5,lr
0,6.902853012084961,6.894402500305175,0.238,1.4619999999427795,2.0474999999999996e-05
1,6.902856826782227,6.892883125,0.184,1.0019999999237061,4.0449999999999994e-05
2,6.8727617263793945,6.690874375915527,0.626,2.694,6.0425e-05
3,6.787080764770508,6.481368438720703,1.527999999847412,5.438000000305176,8.039999999999999e-05
4,6.700096607208252,6.151490935974121,2.883999998931885,9.466,0.000100375
5,6.581276893615723,5.800173048400879,5.340000000610352,15.351999999389648,0.00012035
6,6.771756649017334,6.664580624084473,1.0920000002288819,3.9500000015258787,0.000140325
7,6.879081726074219,6.884776562042236,0.172,0.9,0.0001603
8,6.844717979431152,6.386268439025879,1.9479999999618531,6.121999999847412,0.000180275
9,6.581282138824463,5.605019416809082,5.687999999389649,16.384000003051757,0.00020025000000000002
10,6.422783851623535,5.2869512707519535,8.600000000915527,22.24200000732422,0.000220225
11,6.316596984863281,5.036285956726074,11.209999999084472,26.994,0.0002402
12,6.215880393981934,4.725618184204102,13.776000003662109,31.797999998779297,0.000260175
13,6.1201019287109375,4.492806329040527,16.38000000732422,35.9539999975586,0.00028015
14,6.0147624015808105,4.224120807342529,19.352000017089843,41.35200001708984,0.000300125
15,5.919528484344482,4.005498579330444,22.308000006103516,45.27400001098633,0.0003201
16,5.9355692863464355,3.977339556045532,22.642000013427733,45.521999965820314,0.000340075
17,5.76600456237793,3.6304704219055175,27.532000024414064,51.701999997558595,0.00036005
18,5.675433158874512,3.453242201080322,30.136000014648438,55.02999999023437,0.000380025
19,5.642707347869873,3.345660168457031,31.86000001953125,57.210000014648436,0.00039890574859981264
20,5.569961071014404,3.163454838104248,34.46599999267578,60.40400001708984,0.0003987937008521722
21,5.8800153732299805,4.775262893676758,15.911999992675781,33.832000002441404,0.0003986762099286792
22,6.628257751464844,5.7450297648620605,6.535999997558593,17.284000009765624,0.00039855327905040676
23,6.815000534057617,6.910618126220703,0.118,0.548,0.000398424911587567
24,6.910702228546143,6.909903748931884,0.1,0.5,0.00039829111105941866
25,6.89259147644043,6.809525002746582,0.21,1.074,0.00039815188113417085
26,6.852469444274902,6.6336393734741215,0.552,2.416000002441406,0.0003980072256288824

Now that I have found fused attention as a likely culprit, I will disable it and reset the hyperparameters. Would you recommend also resetting to float16, or is bfloat16 generally preferable?

As a sanity check, I ran training on my setup with the wee hyperparameters again. Besides changing to bfloat16 I made no changes and did not disable fused attn. For that model there is no issue with stability.

PyTorch Image Models org
edited 18 days ago

Okay, there's also a PyTorch version factor in here; this is an interesting one...

On PyTorch 2.7.1 w/ defaults (so F.sdpa & torch.compile active), it fails:

epoch,train_loss,eval_loss,eval_top1,eval_top5,lr
0,6.8745317459106445,6.744654375915528,0.552,2.4539999999427797,4.0475e-05
1,6.8538360595703125,6.632692540588379,0.828,3.124,8.045e-05
2,6.785943508148193,6.446375901489258,1.644,5.405999999961853,0.00012042500000000001
3,6.727376937866211,6.267702362060547,2.617999998779297,8.160000001525878,0.00016040000000000002
4,6.652295112609863,6.054818691101074,3.5839999993896483,10.766000006103516,0.00020037500000000003
5,6.542387008666992,5.731385548400879,5.303999996948242,15.762000000610351,0.00024035000000000004
6,6.442937850952148,5.449951689147949,7.519999999694824,20.190000003051757,0.00028032500000000005
7,6.34382438659668,5.226499792175293,9.49600000366211,23.868,0.00032030000000000003
8,6.252196311950684,4.9809422668457035,11.694000007324219,28.11600001220703,0.000360275
9,6.640519142150879,6.659502597961426,0.6979999999809265,3.0119999999809264,0.00040025000000000005
10,6.823249816894531,6.823515663452149,0.358,1.428,0.00044022500000000003
11,6.88761568069458,6.822595781097412,0.258,1.17,0.00048020000000000007
12,6.8776044845581055,6.8029637496948245,0.28799999996185305,1.2740000003051757,0.000520175
13,nan,nan,0.242,1.106,0.00056015
14,nan,nan,0.1,0.5,0.000600125
15,nan,nan,0.1,0.5,0.0006401

I'm kicking off a 2.7.1 run with no F.sdpa, and based on your finding I suspect it will work...

EDIT: early epochs on the 2.7.1 run w/ no F.sdpa aren't looking amazing, so not sure it's just that...

PyTorch Image Models org

@tony0278611 on your other questions: I've been using bfloat16 as the amp dtype more frequently lately as it is a bit better for stability overall, but I doubt it will have much impact on this model + hparams vs float16. But yeah, with the issue gone I don't see why the original hparams won't work well; as my run on PyTorch 2.5 showed, it was tracking well...

PyTorch Image Models org
edited 18 days ago

And FWIW, PyTorch 2.5 w/ F.sdpa enabled will run faster than 2.7.1 with it turned off... I should also check the PyTorch 2.8 RC and see if the issue is gone there; might be something to file an issue over (PyTorch repo) if it's still broken...

EDIT: So the 2.8 RC is showing similar issues, but it's tracking my 'good' 2.5 run if I disable torch.compile (either edit the yaml or override with --torchcompile ''). Lovely. torch.compile is showing a larger impact than F.sdpa, but my sampling is limited given that it takes a while to confirm...

@rwightman I'm on the PyTorch Compiler team and I saw your thread on X. Do you have instructions on how to reproduce this? It's OK if it's some end-to-end training run (we're happy to help investigate).

PyTorch Image Models org

@richardzou5 thanks Richard, I already passed that along to you, but for anyone else trying to repro:

  • original hparams, 2x gpu, torch.compile and F.sdpa used -- ./distributed_train.sh 2 --data-dir /imagenet --config hparams.yaml
  • 2x gpu, torch.compile disabled (empty quotes overrides 'inductor') -- ./distributed_train.sh 2 --data-dir /imagenet --config hparams.yaml --torchcompile ''
  • 2x gpu, F.sdpa disabled -- TIMM_FUSED_ATTN=0 ./distributed_train.sh 2 --data-dir /imagenet --config hparams.yaml

Where hparams.yaml is taken from the file linked in this model repo.

PyTorch Image Models org
edited 18 days ago

I will see if I get a quicker repro on a single GPU w/ a smaller dataset like https://huggingface.co/datasets/timm/mini-imagenet

About that: I did try on imagewoof first. The hparams as-is had terrible results (~23%), but just changing float16 to bfloat16 changed the results dramatically (~53%). Maybe there is some problem with torch.compile and AMP float16.

PyTorch Image Models org
edited 17 days ago

I couldn't get a reproduction on mini-imagenet, using essentially the same hparams but 100 epochs of a much smaller dataset; I couldn't tell much difference between the PT versions or the use of compile or not :/

But I confirmed the issue again with the PyTorch 2.8 release that just came out. It's tracking my previous runs very closely:

  • PyTorch 2.8 w/ F.sdpa active and torch.compile active: NaN by epoch 15, max accuracy 9.4% (similar to 2.7.1 w/ torch.compile)
  • PyTorch 2.8 w/ F.sdpa active and no torch.compile: 53.65 top-1 by epoch 25 and going steady (very close to PT 2.7.1 w/o torch.compile, or PT 2.5)

2.8 release w/ torch.compile

epoch,train_loss,eval_loss,eval_top1,eval_top5,lr
0,6.877491474151611,6.763444609527588,0.568,2.324,4.0475e-05
1,6.889260292053223,6.847562576293945,0.25,1.234,8.045e-05
2,6.896385192871094,6.860354998474121,0.258,1.13,0.00012042500000000001
3,6.89588737487793,6.727006562194824,0.546,2.4559999999809263,0.00016040000000000002
4,6.755687713623047,6.208034316253662,2.777999999694824,8.692000010986328,0.00020037500000000003
5,6.534149646759033,5.608600653381347,6.125999999084472,17.288000000610353,0.00024035000000000004
6,6.3680644035339355,5.24537740737915,9.396000005493164,23.6800000012207,0.00028032500000000005
7,6.634976863861084,6.671964219970703,0.9680000009155273,3.2260000036621093,0.00032030000000000003
8,6.881913185119629,6.80077054611206,0.522,2.084,0.000360275
9,6.819855213165283,6.488913009185791,1.3319999990081788,4.693999998474121,0.00040025000000000005
10,6.7705793380737305,6.42485011795044,1.7059999984741212,5.895999999389648,0.00044022500000000003
11,6.739638328552246,6.343687911224365,2.259999999847412,7.3039999978637695,0.00048020000000000007
12,6.716366767883301,6.263967693176269,2.5400000004577636,8.108,0.000520175
13,6.745352268218994,6.912474140014648,0.104,0.588,0.00056015
14,6.912156581878662,6.907936562042236,0.104,0.522,0.000600125
15,nan,nan,0.1,0.5,0.0006401

2.8 release w/o torch.compile

epoch,train_loss,eval_loss,eval_top1,eval_top5,lr
0,6.844953536987305,6.55326156463623,1.0479999999809264,4.287999998779297,4.0475e-05
1,6.637157917022705,6.031440936279297,3.3320000021362306,10.786000006713866,8.045e-05
2,6.473331451416016,5.697466504058838,5.4359999996948245,16.21399999938965,0.00012042500000000001
3,6.365767478942871,5.369035927276611,8.443999994506836,21.920000004882812,0.00016040000000000002
4,6.255884170532227,5.079038973999023,11.127999998168946,26.73800001464844,0.00020037500000000003
5,6.132992267608643,4.726981287841797,14.195999995117187,32.752000006103515,0.00024035000000000004
6,5.99842643737793,4.4015024378967285,18.376000004882812,39.22400001708984,0.00028032500000000005
7,5.872110366821289,4.13033897857666,21.599999985351563,43.75200001098633,0.00032030000000000003
8,5.737516403198242,3.828389712371826,25.51400001220703,49.06400000610351,0.000360275
9,5.621604919433594,3.5855789527893065,28.17999999511719,52.780000004882815,0.00040025000000000005
10,5.497817039489746,3.361992049484253,32.08199999511719,57.314000014648435,0.00044022500000000003
11,5.401071548461914,3.2331508876800537,33.921999987792965,59.79400001464844,0.00048020000000000007
12,5.320461273193359,3.089002383041382,36.1499999987793,61.938,0.000520175
13,5.237567901611328,2.932105009918213,38.584000006103516,64.48799997070313,0.00056015
14,5.177468776702881,2.8373233769989015,40.17000001220703,66.33400001708985,0.000600125
15,5.114078521728516,2.7447804624938965,41.68400001708984,67.83599998535156,0.0006401
16,5.057980537414551,2.682340735092163,43.06200000854492,69.34600001953125,0.000680075
17,5.0136566162109375,2.591650791435242,44.40200001220703,70.45400001953125,0.00072005
18,4.9518585205078125,2.5238001443481446,45.278,71.72999999023438,0.000760025
19,4.908901691436768,2.4744316375732422,46.81400001586914,72.75799999267578,0.0007978101276734673
20,4.858033180236816,2.3799344438171386,48.05200000366211,74.15200002197265,0.0007975858919432082
21,4.814340591430664,2.317322551383972,49.28200001831055,75.22200000732421,0.0007973507630487585
22,4.763126850128174,2.2650295224761963,50.38199999267578,76.09799998535156,0.0007971047474362959
23,4.714698791503906,2.1731127770233156,51.52800000244141,77.36400002197266,0.0007968478518504626
24,4.666025638580322,2.164455923919678,52.54799998046875,78.26600000732422,0.0007965800833341808
25,4.637571334838867,2.0745864588928224,53.65200000366211,78.6340000390625,0.0007963014492284596

I have to wait until tomorrow before I can start new runs. Can you test PyTorch 2.7 w/o F.sdpa, w/ torch.compile, and w/ bfloat16?

imagewoof did give me very noticeably unstable results, while imagenette didn't seem to be an issue. But unfortunately a stable imagewoof run (after changing to bfloat16) did not translate into a stable in1k run. It seems that the task needs to be at least a bit challenging for the instability to compound.

PyTorch Image Models org

@tony0278611 I had a 2.7 run w/o F.sdpa and w/ torch.compile and it was failing. So I'm not sure there's a point in trying bfloat16, as the core of the problem seems to be something torch.compile-related, and it's looking like F.sdpa is not a significant factor.

I agree that torch.compile seems to be causing the problems and is the main culprit. At the same time, it does seem that the combination with float16 and F.sdpa makes the problems much worse.

PyTorch Image Models org

@tony0278611 this one is a gift that just keeps giving... going down the rabbit hole.

I tried this on my single-GPU 4090 machine with the same hparams (just one less GPU, and --aug-repeats 0 as that won't work there). No issue, and running both w/ and w/o torch.compile made minimal difference.

I tried it on the same dual-3090 machine most of my runs have been on, but using 1 GPU. It worked fine and matched the 4090 closely enough.

I tried it on the 3090 machine, 2x GPU w/ AMP disabled, running in bfloat16 w/ LayerNorm upcast to float32. It does not learn as effectively, so results are lower than AMP in the early stages, but there were no signs of instability until I stopped. And the results were the same w/ and w/o torch.compile.

So ugh. This one is tricky. I decided to run through a debugging chat with GPT-5 and it looked up several issues. Given what I provided, it feels like there could be a subtle view/storage-offset issue, as that has come up in a number of issues and those can be hard to pinpoint. Also potentially differences in both AMP and DDP use. I'm getting close to throwing in the towel for now though, as I'm running out of things to try. Ultimately someone on the torch team needs to dig deep, and I'm not sure there's enough urgency here.

Though I did confirm, on another hybrid MobileNet v4/v5-style model, that this issue does appear to be messing up my ablations. The original machine I was running on was busy, so I ran a comparison on another that had a 2.7.1 env active and the result was crap. I changed the env back to 2.5 and it's as expected. But in this case it's even more subtle: we're talking 3 percentage points in early training and 0.5% by the end of the run. So numerical differences that are not catastrophic, just not great. Bleh.

PyTorch Image Models org

And one more datapoint: float32 on 2x 3090 w/ torch.compile enabled looks fine, though I had to lower the batch size slightly from the default hparams, to 192 from 256 per GPU.

Thank you so much for your efforts! Your confirming my suspicions that something funky was going on helped a lot.

My solution for now is to just roll back to 2.5.1 and to hope that the torch team fixes it.

I was going crazy trying so many hparams with nothing working out. Before this I was training my own architecture and thought I had messed up somewhere, then tried the baseline only to be confused even more.

Maybe one more thing (I am not 100% sure on this and I might have imagined things): when I disabled AMP for a small section of my code and replaced it with a manual fp32 cast, the throughput actually got higher! I might have changed other things, so I am not entirely sure.

Your anecdote is very valuable. I will always make sure to do ablation studies on the same setup if possible, and if not, I will first verify the results of a nontrivial baseline.

PyTorch Image Models org
edited 16 days ago

K, last few things I tried, hoping to narrow things down.

It does appear attention-related: if I disable torch.compile for just the sdpa part of attention and leave the rest of the model alone, things are fine.

Now, why it happens with both F.sdpa being used and the eager path wasn't clear to me. It sounds, though, like recent versions of inductor can possibly use the equivalent of F.sdpa kernels if they pattern-match the eager impl appropriately? I'm out of bandwidth for the day as I have some other things to tend to, so I didn't have a chance to try and dig into the compiled graph...

I was given a suggestion to disable pattern matching, but that didn't appear to help.
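
(The suggestion was presumably the inductor config toggle below; I'm paraphrasing from memory, so treat the flag as a best guess for your PyTorch version.)

    import torch._inductor.config as inductor_config

    # Best-guess flag: turn off inductor's graph pattern matcher so that the
    # eager attention shouldn't get rewritten into fused SDPA-style kernels.
    inductor_config.pattern_matcher = False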

I modified attn to factor out the sdpa into its own fn and disable compile on just that, and it works in the failure cases...


    @torch.compiler.disable
    def _attn(
            self,
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = maybe_add_mask(attn, attn_mask)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        return x


    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        x = self._attn(q, k, v, attn_mask=attn_mask)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

It's a mystery... the upside of encountering a catastrophic failure is that at least we know something is wrong. Silent, subtle performance differences have surely messed with other ablation studies too.

One thing I thought about: both AMP and multi-GPU seemed to be problematic in combination with torch.compile. You suspected DDP, but could it be gradient accumulation/averaging when using AMP? If you use one GPU with half the physical batch size but the same global batch size, could it also fail?

PyTorch Image Models org

I thought I'd narrowed it down by disabling compile for the attention portion of forward, but no. I fiddled around with the eager attn impl after that, forcing float32 and making other changes that should have altered the behaviour if that were the issue, and they didn't. So then I tried disabling compile for some other unrelated parts of the model, like the MLP forward, global pooling, etc ... and those also toggled good vs bad behaviour. So it seems disabling compile for many different parts of the model will alter the behaviour such that it ends up on the 'good' path and tracks the PT 2.5 or 2.7.1/2.8 w/o compile curves...

Does setting torch._inductor.config.emulate_precision_casts=True help?
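
That is, something along these lines before the torch.compile call (a minimal sketch; exact behaviour depends on the PyTorch build):

    import torch
    import torch._inductor.config as inductor_config

    # Ask inductor to emulate eager-mode precision casts for low-precision dtypes
    # (sketch only; availability and effect depend on the PyTorch version).
    inductor_config.emulate_precision_casts = True

    model = torch.nn.Linear(8, 8)     # stand-in for the real model
    compiled = torch.compile(model)   # compile after setting the config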

I have rolled back to 2.5 for now and unfortunately don't have the capacity to test it.

PyTorch Image Models org

@yfengus it did not change things