diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9331cfbe7..5c960173c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,17 @@

+## v5.10.3 (2025-05-19)
+
+### Fix
+
+* fix: fix abnormally large llamascope L0. (#483)
+
+* make llamascope (base model) compatible with sae.fold_activation_norm_scaling_factor
+
+* make llamascope (base model) compatible with sae.fold_activation_norm_scaling_factor ([`4cafc09`](https://github.com/jbloomAus/SAELens/commit/4cafc09e2bd0c6e265fa2b1234f23fb587b2842a))
+
+
 ## v5.10.2 (2025-05-12)

 ### Fix
diff --git a/check_open_ai_sae_metrics.ipynb b/check_open_ai_sae_metrics.ipynb
deleted file mode 100644
index 8b7fef44f..000000000
--- a/check_open_ai_sae_metrics.ipynb
+++ /dev/null
@@ -1,4338 +0,0 @@
[4,338 deleted lines elided, most of them rendered HTML table output from cell outputs. The notebook loaded per-feature sparsity for the reformatted OpenAI GPT-2 Small SAEs and plotted a histogram of it, then walked each local OAI_* folder, aggregated the per-layer metrics.json files (kl_div, ce_loss, kl_div/ce_loss scores, l2 norms, l0, l1, explained_variance, mse, total_tokens_evaluated) into a DataFrame, and wrote benchmark_stats.csv/.html/.png for each of the six SAE families: v5 32k and 128k for resid_delta_mlp, resid_delta_attn, and resid_post_attn.]
diff --git a/content/dashboard_screenshot.png b/content/dashboard_screenshot.png
deleted file mode 100644
index 732217e5f..000000000
Binary files a/content/dashboard_screenshot.png and /dev/null differ
diff --git a/content/readme_screenshot_predict_pronoun_feature.png b/content/readme_screenshot_predict_pronoun_feature.png
deleted file mode 100644
index 904b1277d..000000000
Binary files a/content/readme_screenshot_predict_pronoun_feature.png and /dev/null differ
diff --git a/docs/feature_dashboards.md b/docs/feature_dashboards.md
deleted file mode 100644
index 05652a9d8..000000000
--- a/docs/feature_dashboards.md
+++ /dev/null
@@ -1,8 +0,0 @@
-
-## Example Output
-
-Here's one feature we found in the residual stream of Layer 10 of GPT-2 Small:
-
-![alt text](../../content/readme_screenshot_predict_pronoun_feature.png). Open `gpt2_resid_pre10_predict_pronoun_feature.html` in your browser to interact with the dashboard (WIP).
-
-Note: this feature could probably split into more mono-semantic features in a larger SAE trained for longer (this one had only about 49152 features, trained on 10M tokens from OpenWebText).
diff --git a/docs/reference.md b/docs/reference.md
deleted file mode 100644
index e69de29bb..000000000
diff --git a/eval_metrics_resid_mid_oai.csv b/eval_metrics_resid_mid_oai.csv
deleted file mode 100644
index 1221f4c4a..000000000
--- a/eval_metrics_resid_mid_oai.csv
+++ /dev/null
@@ -1,25 +0,0 @@
-,version,d_sae,layer,metrics/kl_div_with_sae,metrics/kl_div_with_ablation,metrics/ce_loss_with_sae,metrics/ce_loss_without_sae,metrics/ce_loss_with_ablation,metrics/kl_div_score,metrics/ce_loss_score,metrics/l2_norm_in,metrics/l2_norm_out,metrics/l2_ratio,metrics/l0,metrics/l1,metrics/explained_variance,metrics/mse,metrics/total_tokens_evaluated,filepath
[24 deleted data rows elided: one row per layer (0-11) for each of the v5 128k and v5 32k resid_post_attn SAEs, matching the per-layer metrics summarized in the deleted notebook above.]
diff --git a/make_hf_repo.sh b/make_hf_repo.sh
deleted file mode 100755
index 9f681e5d4..000000000
--- a/make_hf_repo.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-#!/bin/bash
-
-# # Set these variables
-# LOCAL_FOLDER="OAI_GPT2Small_v5_32k_resid_delta_attn"
-# REPO_NAME="GPT2-Small-OAI-v5-32k-attn-out-SAEs"
-# USERNAME="jbloom"
-
-
-# It's actually really easy to upload folders
-huggingface-cli repo create GPT2-Small-OAI-v5-128k-attn-out-SAEs
-cd OAI_GPT2Small_v5_128k_resid_delta_attn
-huggingface-cli upload jbloom/GPT2-Small-OAI-v5-128k-attn-out-SAEs .
-
-huggingface-cli repo create GPT2-Small-OAI-v5-128k-resid-mid-SAEs
-cd OAI_GPT2Small_v5_128k_resid_post_attn
-huggingface-cli upload jbloom/GPT2-Small-OAI-v5-128k-resid-mid-SAEs .
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 7b2fba712..db222d6f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "5.10.2"
+version = "5.10.3"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
diff --git a/sae_lens/__init__.py b/sae_lens/__init__.py
index 899252ff1..96e18d6d8 100644
--- a/sae_lens/__init__.py
+++ b/sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "5.10.2"
+__version__ = "5.10.3"

 import logging
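The version bump accompanies two new training options on `LanguageModelSAERunnerConfig`, added in the next hunk. A minimal usage sketch (illustrative, not from the patch; every other config field is assumed to keep its default):

```python
from sae_lens.config import LanguageModelSAERunnerConfig

# Opt in to Fisher-based redundant-feature removal with the new defaults.
cfg = LanguageModelSAERunnerConfig(
    use_fisher=True,          # use the Fisher information matrix during training
    fisher_lambda_term=1e-4,  # weight of the Fisher term (the new default value)
)
```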
diff --git a/sae_lens/config.py b/sae_lens/config.py
index 6faf8ac9b..c79a9c308 100644
--- a/sae_lens/config.py
+++ b/sae_lens/config.py
@@ -99,6 +99,8 @@ class LanguageModelSAERunnerConfig:
         prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
         jumprelu_init_threshold (float): The threshold to initialize for training JumpReLU SAEs.
         jumprelu_bandwidth (float): Bandwidth for training JumpReLU SAEs.
+        fisher_lambda_term (float): The lambda term for the Fisher information matrix. Used for redundant feature removal.
+        use_fisher (bool): Whether to use the Fisher information matrix for redundant feature removal.
         autocast (bool): Whether to use autocast during training. Saves vram.
         autocast_lm (bool): Whether to use autocast during activation fetching.
         compile_llm (bool): Whether to compile the LLM.
@@ -198,6 +200,10 @@
     jumprelu_init_threshold: float = 0.001
     jumprelu_bandwidth: float = 0.001

+    # Likelihood Context
+    fisher_lambda_term: float = 1e-4
+    use_fisher: bool = False  # whether to use the Fisher information matrix for redundant feature removal
+
     # Performance - see compilation section of lm_runner.py for info
     autocast: bool = False  # autocast to autocast_dtype during training
     autocast_lm: bool = False  # autocast lm during activation fetching
@@ -477,6 +483,8 @@ def get_training_sae_cfg_dict(self) -> dict[str, Any]:
             "jumprelu_init_threshold": self.jumprelu_init_threshold,
             "jumprelu_bandwidth": self.jumprelu_bandwidth,
             "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
+            "fisher_lambda_term": self.fisher_lambda_term,
+            "use_fisher": self.use_fisher,
         }

     def to_dict(self) -> dict[str, Any]:
diff --git a/sae_lens/evals.py b/sae_lens/evals.py
index 63d43d203..f41b7bbc4 100644
--- a/sae_lens/evals.py
+++ b/sae_lens/evals.py
@@ -712,12 +712,15 @@ def single_head_zero_ablate_hook(activations: torch.Tensor, hook: Any):  # noqa:
     )

     def kl(original_logits: torch.Tensor, new_logits: torch.Tensor):
-        original_probs = torch.nn.functional.softmax(original_logits, dim=-1)
-        log_original_probs = torch.log(original_probs)
-        new_probs = torch.nn.functional.softmax(new_logits, dim=-1)
-        log_new_probs = torch.log(new_probs)
-        kl_div = original_probs * (log_original_probs - log_new_probs)
-        return kl_div.sum(dim=-1)
+        # Compute the log-probabilities of the new logits (the approximation).
+        log_probs_new = torch.nn.functional.log_softmax(new_logits, dim=-1)
+        # Compute the probabilities of the original logits (the true distribution).
+        probs_orig = torch.nn.functional.softmax(original_logits, dim=-1)
+        # torch.nn.functional.kl_div expects its first argument to be the log-probabilities
+        # of the approximation (new) and its second to be the true distribution (original)
+        # as probabilities; this computes KL(original || new).
+        kl = torch.nn.functional.kl_div(log_probs_new, probs_orig, reduction="none")
+        return kl.sum(dim=-1)

     if compute_kl:
         recons_kl_div = kl(original_logits, recons_logits)
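The rewritten `kl()` replaces a hand-rolled formula with `torch.nn.functional.kl_div` computed from `log_softmax`, which is more numerically stable than `torch.log(softmax(...))`. A quick equivalence check, written here for illustration rather than taken from the patch:

```python
import torch
import torch.nn.functional as F


def kl(original_logits: torch.Tensor, new_logits: torch.Tensor):
    # The new implementation from the hunk above.
    log_probs_new = F.log_softmax(new_logits, dim=-1)
    probs_orig = F.softmax(original_logits, dim=-1)
    return F.kl_div(log_probs_new, probs_orig, reduction="none").sum(dim=-1)


original_logits = torch.randn(2, 4, 100)
new_logits = torch.randn(2, 4, 100)

# The removed formulation: sum_v p_orig(v) * (log p_orig(v) - log p_new(v)).
direct = (
    F.softmax(original_logits, dim=-1)
    * (F.log_softmax(original_logits, dim=-1) - F.log_softmax(new_logits, dim=-1))
).sum(dim=-1)

assert torch.allclose(kl(original_logits, new_logits), direct, atol=1e-5)
```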
diff --git a/sae_lens/sae.py b/sae_lens/sae.py
index edd873b30..881bc4a4c 100644
--- a/sae_lens/sae.py
+++ b/sae_lens/sae.py
@@ -44,7 +44,7 @@
 @dataclass
 class SAEConfig:
     # architecture details
-    architecture: Literal["standard", "gated", "jumprelu", "topk"]
+    architecture: Literal["standard", "gated", "jumprelu", "topk", "singular_fisher"]
 
     # forward pass details.
     d_in: int
@@ -172,6 +172,9 @@ def __init__(
         elif self.cfg.architecture == "jumprelu":
             self.initialize_weights_jumprelu()
             self.encode = self.encode_jumprelu
+        elif self.cfg.architecture == "singular_fisher":
+            self.initialize_weights_singular_fisher()
+            self.encode = self.encode_singular_fisher
         else:
             raise ValueError(f"Invalid architecture: {self.cfg.architecture}")
@@ -321,6 +324,14 @@ def initialize_weights_jumprelu(self):
         )
         self.initialize_weights_basic()
 
+    def initialize_weights_singular_fisher(self):
+        # The params are identical to the standard SAE,
+        # plus a per-feature threshold parameter.
+        self.threshold = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+        self.initialize_weights_basic()
+
     @overload
     def to(
         self: T,
@@ -441,6 +452,22 @@ def encode_standard(
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
 
+    def encode_singular_fisher(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Calculate SAE features from inputs, zeroing out any feature whose
+        pre-activation does not exceed its learned threshold.
+        """
+        sae_in = self.process_sae_in(x)
+
+        # ... d_in, d_in d_sae -> ... d_sae
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+        return self.hook_sae_acts_post(
+            self.activation_fn(hidden_pre) * (hidden_pre > self.threshold)
+        )
+
     def process_sae_in(
         self, sae_in: Float[torch.Tensor, "... d_in"]
     ) -> Float[torch.Tensor, "... d_sae"]:
@@ -475,7 +502,7 @@ def fold_W_dec_norm(self):
         self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
         self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
         self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
-    elif self.cfg.architecture == "jumprelu":
+    elif self.cfg.architecture in ("jumprelu", "singular_fisher"):
         self.threshold.data = self.threshold.data * W_dec_norms.squeeze()
         self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
     else:
diff --git a/sae_lens/toolkit/pretrained_sae_loaders.py b/sae_lens/toolkit/pretrained_sae_loaders.py
index 583f0528a..ffb08f353 100644
--- a/sae_lens/toolkit/pretrained_sae_loaders.py
+++ b/sae_lens/toolkit/pretrained_sae_loaders.py
@@ -473,11 +473,20 @@ def get_llama_scope_config_from_hf(
     # Model specific parameters
     model_name, d_in = "meta-llama/Llama-3.1-8B", old_cfg_dict["d_model"]
 
+    # Get the norm scaling factor used to rescale the jumprelu threshold.
+    # sae.fold_activation_norm_scaling_factor folds this factor into W_enc,
+    # so the jumprelu threshold must be scaled in the same way.
+    norm_scaling_factor = (
+        d_in**0.5 / old_cfg_dict["dataset_average_activation_norm"]["in"]
+    )
+
     cfg_dict = {
         "architecture": "jumprelu",
-        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"],
+        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"]
+        * norm_scaling_factor,
         # We use a scalar jump_relu_threshold for all features.
         # This is different from Gemma Scope JumpReLU SAEs.
+        # Scaled with norm_scaling_factor to match sae.fold_activation_norm_scaling_factor.
         "d_in": d_in,
         "d_sae": old_cfg_dict["d_sae"],
         "dtype": "bfloat16",
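The loader change above is an arithmetic rescaling; a toy recomputation with made-up numbers (real Llama Scope configs will differ) shows the factor at work:

```python
# Toy recomputation of the loader logic above; the values are invented.
d_in = 4096
old_cfg_dict = {
    "jump_relu_threshold": 0.5,
    "dataset_average_activation_norm": {"in": 128.0},
}

# Activations are normalized so their average norm becomes sqrt(d_in); the
# same factor is folded into W_enc, so the threshold the pre-activations
# are compared against must be scaled identically.
norm_scaling_factor = (
    d_in**0.5 / old_cfg_dict["dataset_average_activation_norm"]["in"]
)
scaled_threshold = old_cfg_dict["jump_relu_threshold"] * norm_scaling_factor

print(norm_scaling_factor)  # sqrt(4096) / 128 = 0.5
print(scaled_threshold)     # 0.5 * 0.5 = 0.25
```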
diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py
index eef087e1a..a14a47f23 100644
--- a/sae_lens/training/sae_trainer.py
+++ b/sae_lens/training/sae_trainer.py
@@ -193,7 +193,7 @@ def fit(self) -> TrainingSAE:
                 self._checkpoint_if_needed()
                 self.n_training_steps += 1
                 self._update_pbar(step_output, pbar)
-            
+
             ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already)
             self._begin_finetuning_if_needed()
diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py
index c82ad20aa..7136524e3 100644
--- a/sae_lens/training/training_sae.py
+++ b/sae_lens/training/training_sae.py
@@ -121,6 +121,8 @@ class TrainingSAEConfig(SAEConfig):
     decoder_heuristic_init_norm: float
     init_encoder_as_decoder_transpose: bool
     scale_sparsity_penalty_by_decoder_norm: bool
+    fisher_lambda_term: float = 1e-4  # matches the LanguageModelSAERunnerConfig default
+    use_fisher: bool = False
 
     @classmethod
     def from_sae_runner_config(
@@ -163,6 +165,8 @@
             model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
             jumprelu_init_threshold=cfg.jumprelu_init_threshold,
             jumprelu_bandwidth=cfg.jumprelu_bandwidth,
+            fisher_lambda_term=cfg.fisher_lambda_term,
+            use_fisher=cfg.use_fisher,
         )
 
     @classmethod
@@ -204,6 +208,8 @@ def to_dict(self) -> dict[str, Any]:
             "normalize_activations": self.normalize_activations,
             "jumprelu_init_threshold": self.jumprelu_init_threshold,
             "jumprelu_bandwidth": self.jumprelu_bandwidth,
+            "fisher_lambda_term": self.fisher_lambda_term,
+            "use_fisher": self.use_fisher,
         }
 
     # this needs to exist so we can initialize the parent sae cfg without the training specific
@@ -258,6 +264,12 @@ def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
             self.log_threshold.data = torch.ones(
                 self.cfg.d_sae, dtype=self.dtype, device=self.device
             ) * np.log(cfg.jumprelu_init_threshold)
+        elif cfg.architecture == "singular_fisher":
+            # Mirror the jumprelu setup: the singular_fisher encode path needs
+            # a bandwidth and an initialized log_threshold.
+            self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_singular_fisher
+            self.bandwidth = cfg.jumprelu_bandwidth
+            self.log_threshold.data = torch.ones(
+                self.cfg.d_sae, dtype=self.dtype, device=self.device
+            ) * np.log(cfg.jumprelu_init_threshold)
         else:
             raise ValueError(f"Unknown architecture: {cfg.architecture}")
@@ -281,9 +293,16 @@ def initialize_weights_jumprelu(self):
         )
         self.initialize_weights_basic()
 
+    def initialize_weights_singular_fisher(self):
+        # Same as the superclass, except we learn a log_threshold parameter
+        # instead of a threshold.
+        self.log_threshold = nn.Parameter(
+            torch.empty(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+        self.initialize_weights_basic()
+
     @property
     def threshold(self) -> torch.Tensor:
-        if self.cfg.architecture != "jumprelu":
-            raise ValueError("Threshold is only defined for Jumprelu SAEs")
+        if self.cfg.architecture not in ("jumprelu", "singular_fisher"):
+            raise ValueError(
+                "Threshold is only defined for JumpReLU and singular_fisher SAEs"
+            )
         return torch.exp(self.log_threshold)
@@ -324,6 +343,26 @@
         return feature_acts, hidden_pre  # type: ignore
 
+    def encode_with_hidden_pre_singular_fisher(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        sae_in: Float[torch.Tensor, "... d_in"] = self.process_sae_in(x)
+        hidden_pre: Float[torch.Tensor, "... d_sae"] = sae_in @ self.W_enc + self.b_enc
+
+        if self.training:
+            hidden_pre = hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
+            # Cache the Fisher regularization term (computed on the pre-activations,
+            # which needs grad mode) so training_forward_pass can add it to the loss.
+            self.fisher_reg_term = self.get_fisher_reg_term(hidden_pre)
+
+        threshold = torch.exp(self.log_threshold)
+        feature_acts = JumpReLU.apply(hidden_pre, threshold, self.bandwidth)
+
+        return feature_acts, hidden_pre
+
     def encode_with_hidden_pre(
         self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
@@ -357,10 +396,14 @@ def encode_with_hidden_pre_gated(
         )  # magnitude_pre_activation_noised)
 
         # Return both the gated feature activations and the magnitude pre-activations
-        return (
-            active_features * feature_magnitudes,
-            magnitude_pre_activation,
-        )  # magnitude_pre_activation_noised
+        out_feature_acts = active_features * feature_magnitudes
+        return out_feature_acts, magnitude_pre_activation
 
     def forward(
         self,
@@ -369,6 +412,27 @@
         feature_acts, _ = self.encode_with_hidden_pre_fn(x)
         return self.decode(feature_acts)
 
+    def get_fisher_reg_term(self, feature_acts: torch.Tensor) -> torch.Tensor:
+        # For this quadratic pseudo-loss the gradient w.r.t. feature_acts is just
+        # 2 * feature_acts; autograd is used so the pseudo-loss can be swapped out.
+        pseudo_loss = torch.sum(feature_acts**2)
+        grads = torch.autograd.grad(
+            pseudo_loss, feature_acts, create_graph=True, retain_graph=True
+        )[0]
+
+        # Center and L2-normalize the gradients along the batch dimension (rows)
+        grads = grads - grads.mean(dim=0, keepdim=True)
+        grads = grads / (grads.norm(dim=0, keepdim=True) + 1e-8)
+
+        # Compute the empirical Fisher matrix (assumes feature_acts is (batch, d_sae))
+        F_empirical = torch.einsum("bi,bj->ij", grads, grads) / feature_acts.shape[0]
+
+        # Remove the diagonal (self-correlation) entries
+        off_diag = F_empirical - torch.diag(torch.diag(F_empirical))
+
+        # Penalize the off-diagonal (cross-feature correlation) terms
+        fisher_reg_term = self.cfg.fisher_lambda_term * torch.sum(off_diag**2)
+
+        return fisher_reg_term
+
     def training_forward_pass(
         self,
         sae_in: torch.Tensor,
@@ -378,6 +442,8 @@
         # do a forward pass to get SAE out, but we also need the
         # hidden pre.
         feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
+        if self.cfg.use_fisher and self.cfg.architecture != "gated":
+            self.fisher_reg_term = self.get_fisher_reg_term(feature_acts)
         sae_out = self.decode(feature_acts)
 
         # MSE LOSS
@@ -407,9 +473,16 @@
             aux_reconstruction_loss = torch.sum(
                 (via_gate_reconstruction - sae_in) ** 2, dim=-1
             ).mean()
+            loss = mse_loss + l1_loss + aux_reconstruction_loss
+
             losses["auxiliary_reconstruction_loss"] = aux_reconstruction_loss
             losses["l1_loss"] = l1_loss
+            if self.cfg.use_fisher:
+                self.fisher_reg_term = self.get_fisher_reg_term(feature_acts)
+                losses["fisher_reg_term"] = self.fisher_reg_term
+                loss += self.fisher_reg_term
+
         elif self.cfg.architecture == "jumprelu":
             threshold = torch.exp(self.log_threshold)
             l0 = torch.sum(Step.apply(hidden_pre, threshold, self.bandwidth), dim=-1)  # type: ignore
@@ -425,6 +498,13 @@
             )
             losses["auxiliary_reconstruction_loss"] = topk_loss
             loss = mse_loss + topk_loss
+        elif self.cfg.architecture == "singular_fisher":
+            threshold = torch.exp(self.log_threshold)
+            l0 = torch.sum(Step.apply(hidden_pre, threshold, self.bandwidth), dim=-1)  # type: ignore
+            l0_loss = (current_l1_coefficient * l0).mean()
+            loss = mse_loss + l0_loss + self.fisher_reg_term
+            losses["l0_loss"] = l0_loss
+            losses["fisher_reg_term"] = self.fisher_reg_term
         else:
             # default SAE sparsity loss
             weighted_feature_acts = feature_acts
@@ -453,6 +533,9 @@
         losses["l1_loss"] = l1_loss
         losses["mse_loss"] = mse_loss
+        # The gated and singular_fisher branches already add the Fisher term to
+        # the loss above, so skip them here to avoid double-counting it.
+        if self.cfg.use_fisher and self.cfg.architecture not in ("gated", "singular_fisher"):
+            losses["fisher_reg_term"] = self.fisher_reg_term
+            loss = loss + self.fisher_reg_term
 
         return TrainStepOutput(
             sae_in=sae_in,
@@ -553,13 +636,13 @@
     def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
-        if self.cfg.architecture == "jumprelu" and "log_threshold" in state_dict:
+        if self.cfg.architecture in ["jumprelu", "singular_fisher"] and "log_threshold" in state_dict:
             threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
             del state_dict["log_threshold"]
             state_dict["threshold"] = threshold
 
     def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
-        if self.cfg.architecture == "jumprelu" and "threshold" in state_dict:
+        if self.cfg.architecture in ["jumprelu", "singular_fisher"] and "threshold" in state_dict:
             threshold = state_dict["threshold"]
             del state_dict["threshold"]
             state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
@@ -627,7 +710,7 @@
     @torch.no_grad()
     def fold_W_dec_norm(self):
-        # need to deal with the jumprelu having a log_threshold in training
-        if self.cfg.architecture == "jumprelu":
+        # jumprelu and singular_fisher keep a log_threshold during training
+        if self.cfg.architecture in ("jumprelu", "singular_fisher"):
             cur_threshold = self.threshold.clone()
             W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
             super().fold_W_dec_norm()
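Assuming the runner API otherwise matches current SAELens, enabling the new options would look roughly like the sketch below; only `architecture="singular_fisher"`, `use_fisher`, and `fisher_lambda_term` come from this diff, while the other arguments are illustrative placeholders:

```python
from sae_lens import LanguageModelSAERunnerConfig

# Hypothetical usage sketch, not a prescribed configuration.
cfg = LanguageModelSAERunnerConfig(
    model_name="gpt2",
    hook_name="blocks.8.hook_resid_pre",
    architecture="singular_fisher",  # JumpReLU-style SAE with a Fisher penalty
    use_fisher=True,                 # add the off-diagonal Fisher penalty to the loss
    fisher_lambda_term=1e-4,         # weight on that penalty
)
```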