diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9331cfbe7..5c960173c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,17 @@
+## v5.10.3 (2025-05-19)
+
+### Fix
+
+* fix: fix abnormally large llamascope L0. (#483)
+
+* make llamascope (base model) compatible with sae.fold_activation_norm_scaling_factor
+
+* make llamascope (base model) compatible with sae.fold_activation_norm_scaling_factor ([`4cafc09`](https://github.com/jbloomAus/SAELens/commit/4cafc09e2bd0c6e265fa2b1234f23fb587b2842a))
+
+
## v5.10.2 (2025-05-12)
### Fix
diff --git a/check_open_ai_sae_metrics.ipynb b/check_open_ai_sae_metrics.ipynb
deleted file mode 100644
index 8b7fef44f..000000000
--- a/check_open_ai_sae_metrics.ipynb
+++ /dev/null
@@ -1,4338 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "from sae_lens.toolkit.pretrained_saes import load_sparsity\n",
- "import plotly.express as px\n",
- "\n",
- "path = \"open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_0\"\n",
- "\n",
- "\n",
- "sparsity = load_sparsity(path) # [\"sparsity\"]\n",
- "\n",
- "\n",
- "px.histogram(sparsity.cpu(), nbins=100).write_html(\"sparsity_histogram.html\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 57,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "OAI_GPT2Small_v5_32k_resid_delta_mlp\n",
- "Loading page (1/2)\n",
- "Rendering (2/2) \n",
- "Done \n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " version | \n",
- " d_sae | \n",
- " layer | \n",
- " kl_div_with_sae | \n",
- " kl_div_with_ablation | \n",
- " ce_loss_with_sae | \n",
- " ce_loss_without_sae | \n",
- " ce_loss_with_ablation | \n",
- " kl_div_score | \n",
- " ce_loss_score | \n",
- " l2_norm_in | \n",
- " l2_norm_out | \n",
- " l2_ratio | \n",
- " l0 | \n",
- " l1 | \n",
- " explained_variance | \n",
- " mse | \n",
- " total_tokens_evaluated | \n",
- " filepath | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_0/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 0 | \n",
- " 0.004845 | \n",
- " 3.094083 | \n",
- " 3.605465 | \n",
- " 3.599065 | \n",
- " 6.694649 | \n",
- " 0.998434 | \n",
- " 0.997933 | \n",
- " 29.933449 | \n",
- " 29.601543 | \n",
- " 0.989371 | \n",
- " 32.000000 | \n",
- " 71.211151 | \n",
- " 0.966797 | \n",
- " 21.729292 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_0/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_1/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 1 | \n",
- " 0.006601 | \n",
- " 0.051053 | \n",
- " 3.605596 | \n",
- " 3.599065 | \n",
- " 3.652537 | \n",
- " 0.870694 | \n",
- " 0.877862 | \n",
- " 18.973736 | \n",
- " 17.917168 | \n",
- " 0.910649 | \n",
- " 32.000000 | \n",
- " 86.565331 | \n",
- " 0.885442 | \n",
- " 25.637442 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_1/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_2/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 2 | \n",
- " 0.009369 | \n",
- " 0.058747 | \n",
- " 3.601879 | \n",
- " 3.599065 | \n",
- " 3.645913 | \n",
- " 0.840524 | \n",
- " 0.939922 | \n",
- " 49.106537 | \n",
- " 47.644482 | \n",
- " 0.888798 | \n",
- " 31.875000 | \n",
- " 85.811630 | \n",
- " 0.974547 | \n",
- " 37.837296 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_2/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_3/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 3 | \n",
- " 0.010681 | \n",
- " 0.070592 | \n",
- " 3.609601 | \n",
- " 3.599065 | \n",
- " 3.658678 | \n",
- " 0.848690 | \n",
- " 0.823245 | \n",
- " 16.987318 | \n",
- " 15.157210 | \n",
- " 0.874669 | \n",
- " 31.911459 | \n",
- " 85.938217 | \n",
- " 0.780534 | \n",
- " 50.548058 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_3/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_4/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 4 | \n",
- " 0.012658 | \n",
- " 0.063325 | \n",
- " 3.611159 | \n",
- " 3.599065 | \n",
- " 3.660080 | \n",
- " 0.800111 | \n",
- " 0.801781 | \n",
- " 17.251986 | \n",
- " 15.012179 | \n",
- " 0.852544 | \n",
- " 31.955566 | \n",
- " 82.476707 | \n",
- " 0.729496 | \n",
- " 63.704514 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_4/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_5/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 5 | \n",
- " 0.014467 | \n",
- " 0.068505 | \n",
- " 3.613976 | \n",
- " 3.599065 | \n",
- " 3.669386 | \n",
- " 0.788825 | \n",
- " 0.787950 | \n",
- " 18.888968 | \n",
- " 16.209919 | \n",
- " 0.848440 | \n",
- " 32.000000 | \n",
- " 81.434013 | \n",
- " 0.717422 | \n",
- " 87.281723 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_5/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_6/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 6 | \n",
- " 0.016600 | \n",
- " 0.075694 | \n",
- " 3.618799 | \n",
- " 3.599065 | \n",
- " 3.676516 | \n",
- " 0.780703 | \n",
- " 0.745207 | \n",
- " 21.466564 | \n",
- " 18.402473 | \n",
- " 0.852635 | \n",
- " 32.000000 | \n",
- " 78.829765 | \n",
- " 0.706308 | \n",
- " 117.072495 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_6/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_7/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 7 | \n",
- " 0.017010 | \n",
- " 0.080486 | \n",
- " 3.614976 | \n",
- " 3.599065 | \n",
- " 3.672712 | \n",
- " 0.788663 | \n",
- " 0.783952 | \n",
- " 25.444439 | \n",
- " 22.004990 | \n",
- " 0.862489 | \n",
- " 32.000000 | \n",
- " 76.419937 | \n",
- " 0.718003 | \n",
- " 157.791412 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_7/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_8/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 8 | \n",
- " 0.018103 | \n",
- " 0.087324 | \n",
- " 3.616245 | \n",
- " 3.599065 | \n",
- " 3.680337 | \n",
- " 0.792688 | \n",
- " 0.788606 | \n",
- " 30.250225 | \n",
- " 26.306936 | \n",
- " 0.867637 | \n",
- " 32.000000 | \n",
- " 76.728195 | \n",
- " 0.723916 | \n",
- " 219.982910 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_8/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_9/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 9 | \n",
- " 0.019997 | \n",
- " 0.097589 | \n",
- " 3.617456 | \n",
- " 3.599065 | \n",
- " 3.696245 | \n",
- " 0.795088 | \n",
- " 0.810751 | \n",
- " 40.192413 | \n",
- " 35.945808 | \n",
- " 0.889800 | \n",
- " 32.000000 | \n",
- " 72.426567 | \n",
- " 0.742352 | \n",
- " 318.143433 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_9/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_10/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 10 | \n",
- " 0.023115 | \n",
- " 0.126748 | \n",
- " 3.617172 | \n",
- " 3.599065 | \n",
- " 3.708984 | \n",
- " 0.817629 | \n",
- " 0.835264 | \n",
- " 81.756828 | \n",
- " 78.393089 | \n",
- " 0.955360 | \n",
- " 32.000000 | \n",
- " 50.458115 | \n",
- " 0.792657 | \n",
- " 514.553589 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_10/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_11/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 11 | \n",
- " 0.028953 | \n",
- " 0.173841 | \n",
- " 3.623718 | \n",
- " 3.599065 | \n",
- " 3.783318 | \n",
- " 0.833454 | \n",
- " 0.866197 | \n",
- " 92.906296 | \n",
- " 87.663773 | \n",
- " 0.923381 | \n",
- " 32.000000 | \n",
- " 73.987030 | \n",
- " 0.840599 | \n",
- " 742.957520 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_mlp/v5_32k_layer_11/metrics.json | \n",
- "
\n",
- " \n",
- "
\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "OAI_GPT2Small_v5_128k_resid_delta_mlp\n",
- "Loading page (1/2)\n",
- "Rendering (2/2) \n",
- "Done \n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " \n",
- " \n",
- " | \n",
- " version | \n",
- " d_sae | \n",
- " layer | \n",
- " kl_div_with_sae | \n",
- " kl_div_with_ablation | \n",
- " ce_loss_with_sae | \n",
- " ce_loss_without_sae | \n",
- " ce_loss_with_ablation | \n",
- " kl_div_score | \n",
- " ce_loss_score | \n",
- " l2_norm_in | \n",
- " l2_norm_out | \n",
- " l2_ratio | \n",
- " l0 | \n",
- " l1 | \n",
- " explained_variance | \n",
- " mse | \n",
- " total_tokens_evaluated | \n",
- " filepath | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_0/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 0 | \n",
- " 0.002883 | \n",
- " 3.094083 | \n",
- " 3.601894 | \n",
- " 3.599065 | \n",
- " 6.694649 | \n",
- " 0.999068 | \n",
- " 0.999086 | \n",
- " 29.933449 | \n",
- " 29.719198 | \n",
- " 0.993006 | \n",
- " 31.999350 | \n",
- " 61.198639 | \n",
- " 0.977519 | \n",
- " 13.276111 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_0/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_1/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 1 | \n",
- " 0.004242 | \n",
- " 0.051053 | \n",
- " 3.599821 | \n",
- " 3.599065 | \n",
- " 3.652537 | \n",
- " 0.916919 | \n",
- " 0.985857 | \n",
- " 18.973736 | \n",
- " 18.205364 | \n",
- " 0.932765 | \n",
- " 32.000000 | \n",
- " 84.060181 | \n",
- " 0.916917 | \n",
- " 17.682375 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_1/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_2/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 2 | \n",
- " 0.006752 | \n",
- " 0.058747 | \n",
- " 3.602034 | \n",
- " 3.599065 | \n",
- " 3.645913 | \n",
- " 0.885070 | \n",
- " 0.936614 | \n",
- " 49.106537 | \n",
- " 47.976685 | \n",
- " 0.912467 | \n",
- " 31.984375 | \n",
- " 82.676140 | \n",
- " 0.981090 | \n",
- " 28.213791 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_2/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_3/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 3 | \n",
- " 0.008532 | \n",
- " 0.070592 | \n",
- " 3.607723 | \n",
- " 3.599065 | \n",
- " 3.658678 | \n",
- " 0.879137 | \n",
- " 0.854756 | \n",
- " 16.987318 | \n",
- " 15.537837 | \n",
- " 0.899518 | \n",
- " 31.905111 | \n",
- " 81.886444 | \n",
- " 0.827494 | \n",
- " 38.814980 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_3/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_4/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 4 | \n",
- " 0.010456 | \n",
- " 0.063325 | \n",
- " 3.611892 | \n",
- " 3.599065 | \n",
- " 3.660080 | \n",
- " 0.834885 | \n",
- " 0.789762 | \n",
- " 17.251986 | \n",
- " 15.434065 | \n",
- " 0.879446 | \n",
- " 31.984375 | \n",
- " 78.371658 | \n",
- " 0.780778 | \n",
- " 50.818203 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_4/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_5/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 5 | \n",
- " 0.011618 | \n",
- " 0.068505 | \n",
- " 3.609259 | \n",
- " 3.599065 | \n",
- " 3.669386 | \n",
- " 0.830411 | \n",
- " 0.855026 | \n",
- " 18.888968 | \n",
- " 16.669851 | \n",
- " 0.873767 | \n",
- " 31.983725 | \n",
- " 77.139160 | \n",
- " 0.765372 | \n",
- " 71.502914 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_5/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_6/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 6 | \n",
- " 0.014023 | \n",
- " 0.075694 | \n",
- " 3.614241 | \n",
- " 3.599065 | \n",
- " 3.676516 | \n",
- " 0.814737 | \n",
- " 0.804058 | \n",
- " 21.466564 | \n",
- " 18.890602 | \n",
- " 0.875795 | \n",
- " 31.998373 | \n",
- " 74.617645 | \n",
- " 0.754052 | \n",
- " 97.221252 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_6/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_7/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 7 | \n",
- " 0.014005 | \n",
- " 0.080486 | \n",
- " 3.609832 | \n",
- " 3.599065 | \n",
- " 3.672712 | \n",
- " 0.825992 | \n",
- " 0.853803 | \n",
- " 25.444439 | \n",
- " 22.559195 | \n",
- " 0.884380 | \n",
- " 31.997885 | \n",
- " 72.551300 | \n",
- " 0.764806 | \n",
- " 130.780823 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_7/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_8/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 8 | \n",
- " 0.014951 | \n",
- " 0.087324 | \n",
- " 3.615898 | \n",
- " 3.599065 | \n",
- " 3.680337 | \n",
- " 0.828790 | \n",
- " 0.792874 | \n",
- " 30.250225 | \n",
- " 26.934555 | \n",
- " 0.888432 | \n",
- " 31.996908 | \n",
- " 72.559128 | \n",
- " 0.768030 | \n",
- " 183.634979 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_8/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_9/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 9 | \n",
- " 0.015919 | \n",
- " 0.097589 | \n",
- " 3.615844 | \n",
- " 3.599065 | \n",
- " 3.696245 | \n",
- " 0.836876 | \n",
- " 0.827333 | \n",
- " 40.192413 | \n",
- " 36.576889 | \n",
- " 0.905792 | \n",
- " 32.000000 | \n",
- " 68.976295 | \n",
- " 0.782096 | \n",
- " 267.216095 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_9/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_10/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 10 | \n",
- " 0.019077 | \n",
- " 0.126748 | \n",
- " 3.616906 | \n",
- " 3.599065 | \n",
- " 3.708984 | \n",
- " 0.849488 | \n",
- " 0.837687 | \n",
- " 81.756828 | \n",
- " 78.834885 | \n",
- " 0.960933 | \n",
- " 32.000000 | \n",
- " 48.323631 | \n",
- " 0.819787 | \n",
- " 443.108887 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_10/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_11/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 11 | \n",
- " 0.024072 | \n",
- " 0.173841 | \n",
- " 3.620230 | \n",
- " 3.599065 | \n",
- " 3.783318 | \n",
- " 0.861531 | \n",
- " 0.885130 | \n",
- " 92.906296 | \n",
- " 88.290031 | \n",
- " 0.932119 | \n",
- " 32.000000 | \n",
- " 71.255219 | \n",
- " 0.861334 | \n",
- " 635.459473 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_mlp/v5_128k_layer_11/metrics.json | \n",
- "
\n",
- " \n",
- "
\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "OAI_GPT2Small_v5_32k_resid_delta_attn\n",
- "Loading page (1/2)\n",
- "Rendering (2/2) \n",
- "Done \n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " \n",
- " \n",
- " | \n",
- " version | \n",
- " d_sae | \n",
- " layer | \n",
- " kl_div_with_sae | \n",
- " kl_div_with_ablation | \n",
- " ce_loss_with_sae | \n",
- " ce_loss_without_sae | \n",
- " ce_loss_with_ablation | \n",
- " kl_div_score | \n",
- " ce_loss_score | \n",
- " l2_norm_in | \n",
- " l2_norm_out | \n",
- " l2_ratio | \n",
- " l0 | \n",
- " l1 | \n",
- " explained_variance | \n",
- " mse | \n",
- " total_tokens_evaluated | \n",
- " filepath | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_0/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 0 | \n",
- " 0.004215 | \n",
- " 2.121528 | \n",
- " 3.603763 | \n",
- " 3.599065 | \n",
- " 5.748601 | \n",
- " 0.998013 | \n",
- " 0.997814 | \n",
- " 32.013138 | \n",
- " 31.891546 | \n",
- " 0.996240 | \n",
- " 31.993979 | \n",
- " 42.146969 | \n",
- " 0.966712 | \n",
- " 8.074387 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_0/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_1/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 1 | \n",
- " 0.001881 | \n",
- " 0.024066 | \n",
- " 3.601315 | \n",
- " 3.599065 | \n",
- " 3.620633 | \n",
- " 0.921821 | \n",
- " 0.895681 | \n",
- " 9.714648 | \n",
- " 9.157854 | \n",
- " 0.937912 | \n",
- " 32.000000 | \n",
- " 82.860558 | \n",
- " 0.868087 | \n",
- " 9.541979 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_1/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_2/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 2 | \n",
- " 0.002342 | \n",
- " 0.031005 | \n",
- " 3.600916 | \n",
- " 3.599065 | \n",
- " 3.626660 | \n",
- " 0.924478 | \n",
- " 0.932912 | \n",
- " 8.641823 | \n",
- " 8.045539 | \n",
- " 0.929619 | \n",
- " 32.000000 | \n",
- " 82.863617 | \n",
- " 0.853233 | \n",
- " 9.740603 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_2/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_3/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 3 | \n",
- " 0.002842 | \n",
- " 0.025134 | \n",
- " 3.602360 | \n",
- " 3.599065 | \n",
- " 3.628661 | \n",
- " 0.886914 | \n",
- " 0.888638 | \n",
- " 8.571012 | \n",
- " 7.753783 | \n",
- " 0.904760 | \n",
- " 32.000000 | \n",
- " 81.563385 | \n",
- " 0.815689 | \n",
- " 13.545696 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_3/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_4/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 4 | \n",
- " 0.003790 | \n",
- " 0.026723 | \n",
- " 3.603180 | \n",
- " 3.599065 | \n",
- " 3.632133 | \n",
- " 0.858169 | \n",
- " 0.875538 | \n",
- " 9.123016 | \n",
- " 7.993571 | \n",
- " 0.877799 | \n",
- " 32.000000 | \n",
- " 79.127533 | \n",
- " 0.772355 | \n",
- " 19.599684 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_4/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_5/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 5 | \n",
- " 0.004055 | \n",
- " 0.031378 | \n",
- " 3.602063 | \n",
- " 3.599065 | \n",
- " 3.627760 | \n",
- " 0.870766 | \n",
- " 0.895510 | \n",
- " 10.034396 | \n",
- " 8.880257 | \n",
- " 0.886243 | \n",
- " 32.000000 | \n",
- " 77.734062 | \n",
- " 0.782550 | \n",
- " 24.739639 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_5/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_6/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 6 | \n",
- " 0.005056 | \n",
- " 0.032789 | \n",
- " 3.604352 | \n",
- " 3.599065 | \n",
- " 3.634286 | \n",
- " 0.845789 | \n",
- " 0.849889 | \n",
- " 11.678066 | \n",
- " 10.210135 | \n",
- " 0.877007 | \n",
- " 32.000000 | \n",
- " 74.857086 | \n",
- " 0.753413 | \n",
- " 35.039909 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_6/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_7/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 7 | \n",
- " 0.004876 | \n",
- " 0.034661 | \n",
- " 3.605160 | \n",
- " 3.599065 | \n",
- " 3.634834 | \n",
- " 0.859336 | \n",
- " 0.829584 | \n",
- " 13.650208 | \n",
- " 12.291288 | \n",
- " 0.902289 | \n",
- " 32.000000 | \n",
- " 71.106369 | \n",
- " 0.783828 | \n",
- " 41.602531 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_7/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_8/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 8 | \n",
- " 0.005558 | \n",
- " 0.029382 | \n",
- " 3.604657 | \n",
- " 3.599065 | \n",
- " 3.625802 | \n",
- " 0.810828 | \n",
- " 0.790831 | \n",
- " 16.137949 | \n",
- " 14.443828 | \n",
- " 0.896318 | \n",
- " 32.000000 | \n",
- " 71.654037 | \n",
- " 0.759009 | \n",
- " 57.211456 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_8/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_9/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 9 | \n",
- " 0.004499 | \n",
- " 0.028918 | \n",
- " 3.601583 | \n",
- " 3.599065 | \n",
- " 3.636501 | \n",
- " 0.844428 | \n",
- " 0.932728 | \n",
- " 20.912498 | \n",
- " 19.139347 | \n",
- " 0.917789 | \n",
- " 32.000000 | \n",
- " 65.829063 | \n",
- " 0.780790 | \n",
- " 77.843826 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_9/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_10/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 10 | \n",
- " 0.003999 | \n",
- " 0.024752 | \n",
- " 3.602678 | \n",
- " 3.599065 | \n",
- " 3.640488 | \n",
- " 0.838441 | \n",
- " 0.912773 | \n",
- " 31.821377 | \n",
- " 30.121130 | \n",
- " 0.945757 | \n",
- " 32.000000 | \n",
- " 55.563881 | \n",
- " 0.819294 | \n",
- " 125.342606 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_10/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_11/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 11 | \n",
- " 0.003772 | \n",
- " 0.106875 | \n",
- " 3.601134 | \n",
- " 3.599065 | \n",
- " 3.730870 | \n",
- " 0.964708 | \n",
- " 0.984299 | \n",
- " 280.864441 | \n",
- " 280.543213 | \n",
- " 0.998668 | \n",
- " 31.687500 | \n",
- " 17.145309 | \n",
- " 0.967849 | \n",
- " 180.271027 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_delta_attn/v5_32k_layer_11/metrics.json | \n",
- "
\n",
- " \n",
- "
\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "OAI_GPT2Small_v5_128k_resid_delta_attn\n",
- "Loading page (1/2)\n",
- "Rendering (2/2) \n",
- "Done \n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " \n",
- " \n",
- " | \n",
- " version | \n",
- " d_sae | \n",
- " layer | \n",
- " kl_div_with_sae | \n",
- " kl_div_with_ablation | \n",
- " ce_loss_with_sae | \n",
- " ce_loss_without_sae | \n",
- " ce_loss_with_ablation | \n",
- " kl_div_score | \n",
- " ce_loss_score | \n",
- " l2_norm_in | \n",
- " l2_norm_out | \n",
- " l2_ratio | \n",
- " l0 | \n",
- " l1 | \n",
- " explained_variance | \n",
- " mse | \n",
- " total_tokens_evaluated | \n",
- " filepath | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_0/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 0 | \n",
- " 0.003102 | \n",
- " 2.121528 | \n",
- " 3.600031 | \n",
- " 3.599065 | \n",
- " 5.748601 | \n",
- " 0.998538 | \n",
- " 0.999551 | \n",
- " 32.013138 | \n",
- " 31.910549 | \n",
- " 0.996805 | \n",
- " 31.989422 | \n",
- " 39.967445 | \n",
- " 0.973218 | \n",
- " 6.201825 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_0/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_1/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 1 | \n",
- " 0.001700 | \n",
- " 0.024066 | \n",
- " 3.601379 | \n",
- " 3.599065 | \n",
- " 3.620633 | \n",
- " 0.929377 | \n",
- " 0.892674 | \n",
- " 9.714648 | \n",
- " 9.198002 | \n",
- " 0.942320 | \n",
- " 31.999023 | \n",
- " 78.196198 | \n",
- " 0.879798 | \n",
- " 8.594531 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_1/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_2/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 2 | \n",
- " 0.002000 | \n",
- " 0.031005 | \n",
- " 3.599786 | \n",
- " 3.599065 | \n",
- " 3.626660 | \n",
- " 0.935502 | \n",
- " 0.973847 | \n",
- " 8.641823 | \n",
- " 8.086132 | \n",
- " 0.934166 | \n",
- " 31.987143 | \n",
- " 79.245674 | \n",
- " 0.869267 | \n",
- " 8.591892 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_2/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_3/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 3 | \n",
- " 0.002505 | \n",
- " 0.025134 | \n",
- " 3.601374 | \n",
- " 3.599065 | \n",
- " 3.628661 | \n",
- " 0.900320 | \n",
- " 0.921957 | \n",
- " 8.571012 | \n",
- " 7.854372 | \n",
- " 0.916164 | \n",
- " 32.000000 | \n",
- " 77.344894 | \n",
- " 0.837647 | \n",
- " 11.842029 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_3/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_4/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 4 | \n",
- " 0.003484 | \n",
- " 0.026723 | \n",
- " 3.601578 | \n",
- " 3.599065 | \n",
- " 3.632133 | \n",
- " 0.869623 | \n",
- " 0.924002 | \n",
- " 9.123016 | \n",
- " 8.117954 | \n",
- " 0.891417 | \n",
- " 31.999350 | \n",
- " 75.460785 | \n",
- " 0.798166 | \n",
- " 17.370716 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_4/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_5/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 5 | \n",
- " 0.003556 | \n",
- " 0.031378 | \n",
- " 3.601702 | \n",
- " 3.599065 | \n",
- " 3.627760 | \n",
- " 0.886681 | \n",
- " 0.908081 | \n",
- " 10.034396 | \n",
- " 8.960555 | \n",
- " 0.894489 | \n",
- " 31.998373 | \n",
- " 73.329163 | \n",
- " 0.804148 | \n",
- " 22.216118 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_5/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_6/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 6 | \n",
- " 0.004436 | \n",
- " 0.032789 | \n",
- " 3.603125 | \n",
- " 3.599065 | \n",
- " 3.634286 | \n",
- " 0.864723 | \n",
- " 0.884709 | \n",
- " 11.678066 | \n",
- " 10.348961 | \n",
- " 0.888121 | \n",
- " 31.998699 | \n",
- " 70.520447 | \n",
- " 0.776109 | \n",
- " 31.851517 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_6/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_7/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 7 | \n",
- " 0.004356 | \n",
- " 0.034661 | \n",
- " 3.602255 | \n",
- " 3.599065 | \n",
- " 3.634834 | \n",
- " 0.874331 | \n",
- " 0.910796 | \n",
- " 13.650208 | \n",
- " 12.425106 | \n",
- " 0.911562 | \n",
- " 31.999350 | \n",
- " 66.425446 | \n",
- " 0.803248 | \n",
- " 37.544487 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_7/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_8/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 8 | \n",
- " 0.005140 | \n",
- " 0.029382 | \n",
- " 3.602755 | \n",
- " 3.599065 | \n",
- " 3.625802 | \n",
- " 0.825072 | \n",
- " 0.861990 | \n",
- " 16.137949 | \n",
- " 14.539435 | \n",
- " 0.902075 | \n",
- " 32.000000 | \n",
- " 66.759209 | \n",
- " 0.776774 | \n",
- " 52.779171 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_8/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_9/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 9 | \n",
- " 0.004128 | \n",
- " 0.028918 | \n",
- " 3.601983 | \n",
- " 3.599065 | \n",
- " 3.636501 | \n",
- " 0.857246 | \n",
- " 0.922041 | \n",
- " 20.912498 | \n",
- " 19.252647 | \n",
- " 0.923308 | \n",
- " 31.999350 | \n",
- " 60.197113 | \n",
- " 0.800841 | \n",
- " 70.065033 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_9/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_10/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 10 | \n",
- " 0.003733 | \n",
- " 0.024752 | \n",
- " 3.601586 | \n",
- " 3.599065 | \n",
- " 3.640488 | \n",
- " 0.849176 | \n",
- " 0.939128 | \n",
- " 31.821377 | \n",
- " 30.270412 | \n",
- " 0.950137 | \n",
- " 32.000000 | \n",
- " 52.307262 | \n",
- " 0.836825 | \n",
- " 112.167625 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_10/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_11/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 11 | \n",
- " 0.003718 | \n",
- " 0.106875 | \n",
- " 3.600695 | \n",
- " 3.599065 | \n",
- " 3.730870 | \n",
- " 0.965215 | \n",
- " 0.987627 | \n",
- " 280.864441 | \n",
- " 280.557678 | \n",
- " 0.998717 | \n",
- " 31.750000 | \n",
- " 20.949717 | \n",
- " 0.969702 | \n",
- " 169.453140 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_delta_attn/v5_128k_layer_11/metrics.json | \n",
- "
\n",
- " \n",
- "
\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "OAI_GPT2Small_v5_32k_resid_post_attn\n",
- "Loading page (1/2)\n",
- "Rendering (2/2) \n",
- "Done \n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " \n",
- " \n",
- " | \n",
- " version | \n",
- " d_sae | \n",
- " layer | \n",
- " kl_div_with_sae | \n",
- " kl_div_with_ablation | \n",
- " ce_loss_with_sae | \n",
- " ce_loss_without_sae | \n",
- " ce_loss_with_ablation | \n",
- " kl_div_score | \n",
- " ce_loss_score | \n",
- " l2_norm_in | \n",
- " l2_norm_out | \n",
- " l2_ratio | \n",
- " l0 | \n",
- " l1 | \n",
- " explained_variance | \n",
- " mse | \n",
- " total_tokens_evaluated | \n",
- " filepath | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_0/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 0 | \n",
- " 0.012114 | \n",
- " 12.480284 | \n",
- " 3.614269 | \n",
- " 3.599065 | \n",
- " 15.861977 | \n",
- " 0.999029 | \n",
- " 0.998760 | \n",
- " 32.707962 | \n",
- " 32.566856 | \n",
- " 0.995715 | \n",
- " 31.989258 | \n",
- " 44.145779 | \n",
- " 0.967972 | \n",
- " 8.511816 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_0/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_1/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 1 | \n",
- " 0.010769 | \n",
- " 16.217104 | \n",
- " 3.608070 | \n",
- " 3.599065 | \n",
- " 19.600266 | \n",
- " 0.999336 | \n",
- " 0.999437 | \n",
- " 56.929867 | \n",
- " 56.553204 | \n",
- " 0.993499 | \n",
- " 32.000000 | \n",
- " 50.736076 | \n",
- " 0.960833 | \n",
- " 47.572258 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_1/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_2/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 2 | \n",
- " 0.020159 | \n",
- " 12.813511 | \n",
- " 3.623466 | \n",
- " 3.599065 | \n",
- " 16.327873 | \n",
- " 0.998427 | \n",
- " 0.998083 | \n",
- " 68.907532 | \n",
- " 68.362450 | \n",
- " 0.991143 | \n",
- " 31.999350 | \n",
- " 52.513954 | \n",
- " 0.957198 | \n",
- " 71.102158 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_2/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_3/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 3 | \n",
- " 0.035772 | \n",
- " 10.101868 | \n",
- " 3.636966 | \n",
- " 3.599065 | \n",
- " 13.548822 | \n",
- " 0.996459 | \n",
- " 0.996191 | \n",
- " 103.711441 | \n",
- " 102.807175 | \n",
- " 0.986217 | \n",
- " 31.925781 | \n",
- " 53.469627 | \n",
- " 0.965960 | \n",
- " 125.999268 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_3/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_4/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 4 | \n",
- " 0.057062 | \n",
- " 13.249713 | \n",
- " 3.668521 | \n",
- " 3.599065 | \n",
- " 16.699104 | \n",
- " 0.995693 | \n",
- " 0.994698 | \n",
- " 111.403282 | \n",
- " 109.931808 | \n",
- " 0.979109 | \n",
- " 31.931152 | \n",
- " 55.739113 | \n",
- " 0.951081 | \n",
- " 211.805756 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_4/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_5/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 5 | \n",
- " 0.069624 | \n",
- " 11.519682 | \n",
- " 3.673068 | \n",
- " 3.599065 | \n",
- " 14.860109 | \n",
- " 0.993956 | \n",
- " 0.993428 | \n",
- " 119.651489 | \n",
- " 117.608505 | \n",
- " 0.973188 | \n",
- " 31.877768 | \n",
- " 56.817348 | \n",
- " 0.937413 | \n",
- " 317.223206 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_5/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_6/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 6 | \n",
- " 0.087217 | \n",
- " 6.933250 | \n",
- " 3.691099 | \n",
- " 3.599065 | \n",
- " 10.522690 | \n",
- " 0.987420 | \n",
- " 0.986707 | \n",
- " 128.847931 | \n",
- " 126.086975 | \n",
- " 0.967220 | \n",
- " 31.995281 | \n",
- " 58.897686 | \n",
- " 0.923337 | \n",
- " 455.123871 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_6/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_7/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 7 | \n",
- " 0.102682 | \n",
- " 9.511523 | \n",
- " 3.708053 | \n",
- " 3.599065 | \n",
- " 13.054041 | \n",
- " 0.989204 | \n",
- " 0.988473 | \n",
- " 140.905991 | \n",
- " 137.484528 | \n",
- " 0.964016 | \n",
- " 31.886557 | \n",
- " 56.659546 | \n",
- " 0.910139 | \n",
- " 636.605774 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_7/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_8/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 8 | \n",
- " 0.123620 | \n",
- " 7.897106 | \n",
- " 3.733898 | \n",
- " 3.599065 | \n",
- " 11.460873 | \n",
- " 0.984346 | \n",
- " 0.982850 | \n",
- " 157.343246 | \n",
- " 153.033264 | \n",
- " 0.961115 | \n",
- " 31.998699 | \n",
- " 55.846397 | \n",
- " 0.892514 | \n",
- " 920.340698 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_8/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_9/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 9 | \n",
- " 0.136715 | \n",
- " 5.396312 | \n",
- " 3.744588 | \n",
- " 3.599065 | \n",
- " 8.970472 | \n",
- " 0.974665 | \n",
- " 0.972908 | \n",
- " 181.313721 | \n",
- " 175.997467 | \n",
- " 0.960342 | \n",
- " 31.993652 | \n",
- " 54.271484 | \n",
- " 0.871873 | \n",
- " 1386.967773 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_9/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_10/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 10 | \n",
- " 0.150131 | \n",
- " 6.193092 | \n",
- " 3.750826 | \n",
- " 3.599065 | \n",
- " 9.754217 | \n",
- " 0.975758 | \n",
- " 0.975344 | \n",
- " 224.287598 | \n",
- " 217.772461 | \n",
- " 0.962987 | \n",
- " 32.000000 | \n",
- " 49.592400 | \n",
- " 0.851867 | \n",
- " 2193.096680 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_10/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_11/metrics.json | \n",
- " 5 | \n",
- " 32 | \n",
- " 11 | \n",
- " 0.175661 | \n",
- " 13.087515 | \n",
- " 3.781703 | \n",
- " 3.599065 | \n",
- " 16.484846 | \n",
- " 0.986578 | \n",
- " 0.985826 | \n",
- " 395.539520 | \n",
- " 390.682129 | \n",
- " 0.987008 | \n",
- " 31.990072 | \n",
- " 32.074562 | \n",
- " 0.846984 | \n",
- " 3677.226074 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_32k_resid_post_attn/v5_32k_layer_11/metrics.json | \n",
- "
\n",
- " \n",
- "
\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "OAI_GPT2Small_v5_128k_resid_post_attn\n",
- "Loading page (1/2)\n",
- "Rendering (2/2) \n",
- "Done \n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " \n",
- " \n",
- " | \n",
- " version | \n",
- " d_sae | \n",
- " layer | \n",
- " kl_div_with_sae | \n",
- " kl_div_with_ablation | \n",
- " ce_loss_with_sae | \n",
- " ce_loss_without_sae | \n",
- " ce_loss_with_ablation | \n",
- " kl_div_score | \n",
- " ce_loss_score | \n",
- " l2_norm_in | \n",
- " l2_norm_out | \n",
- " l2_ratio | \n",
- " l0 | \n",
- " l1 | \n",
- " explained_variance | \n",
- " mse | \n",
- " total_tokens_evaluated | \n",
- " filepath | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_0/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 0 | \n",
- " 0.003843 | \n",
- " 12.480284 | \n",
- " 3.603422 | \n",
- " 3.599065 | \n",
- " 15.861977 | \n",
- " 0.999692 | \n",
- " 0.999645 | \n",
- " 32.707962 | \n",
- " 32.607395 | \n",
- " 0.996936 | \n",
- " 31.980795 | \n",
- " 44.247345 | \n",
- " 0.976495 | \n",
- " 5.842685 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_0/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_1/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 1 | \n",
- " 0.006731 | \n",
- " 16.217104 | \n",
- " 3.605462 | \n",
- " 3.599065 | \n",
- " 19.600266 | \n",
- " 0.999585 | \n",
- " 0.999600 | \n",
- " 56.929867 | \n",
- " 56.693493 | \n",
- " 0.995869 | \n",
- " 31.999023 | \n",
- " 59.953354 | \n",
- " 0.972269 | \n",
- " 31.145605 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_1/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_2/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 2 | \n",
- " 0.014035 | \n",
- " 12.813511 | \n",
- " 3.612803 | \n",
- " 3.599065 | \n",
- " 16.327873 | \n",
- " 0.998905 | \n",
- " 0.998921 | \n",
- " 68.907532 | \n",
- " 68.518921 | \n",
- " 0.993602 | \n",
- " 31.999350 | \n",
- " 53.266327 | \n",
- " 0.969186 | \n",
- " 49.732628 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_2/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_3/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 3 | \n",
- " 0.023511 | \n",
- " 10.101868 | \n",
- " 3.622775 | \n",
- " 3.599065 | \n",
- " 13.548822 | \n",
- " 0.997673 | \n",
- " 0.997617 | \n",
- " 103.711441 | \n",
- " 103.026962 | \n",
- " 0.989538 | \n",
- " 31.983074 | \n",
- " 54.825603 | \n",
- " 0.975295 | \n",
- " 89.620079 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_3/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_4/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 4 | \n",
- " 0.037530 | \n",
- " 13.249713 | \n",
- " 3.640705 | \n",
- " 3.599065 | \n",
- " 16.699104 | \n",
- " 0.997167 | \n",
- " 0.996821 | \n",
- " 111.403282 | \n",
- " 110.286774 | \n",
- " 0.984007 | \n",
- " 31.948568 | \n",
- " 56.497829 | \n",
- " 0.963923 | \n",
- " 153.826294 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_4/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_5/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 5 | \n",
- " 0.047411 | \n",
- " 11.519682 | \n",
- " 3.644166 | \n",
- " 3.599065 | \n",
- " 14.860109 | \n",
- " 0.995884 | \n",
- " 0.995995 | \n",
- " 119.651489 | \n",
- " 118.053459 | \n",
- " 0.978980 | \n",
- " 31.958008 | \n",
- " 56.149471 | \n",
- " 0.952276 | \n",
- " 238.126266 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_5/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_6/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 6 | \n",
- " 0.060794 | \n",
- " 6.933250 | \n",
- " 3.655243 | \n",
- " 3.599065 | \n",
- " 10.522690 | \n",
- " 0.991232 | \n",
- " 0.991886 | \n",
- " 128.847931 | \n",
- " 126.688171 | \n",
- " 0.974231 | \n",
- " 31.996094 | \n",
- " 56.119469 | \n",
- " 0.940335 | \n",
- " 349.569763 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_6/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_7/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 7 | \n",
- " 0.073590 | \n",
- " 9.511523 | \n",
- " 3.668211 | \n",
- " 3.599065 | \n",
- " 13.054041 | \n",
- " 0.992263 | \n",
- " 0.992687 | \n",
- " 140.905991 | \n",
- " 138.128754 | \n",
- " 0.970831 | \n",
- " 31.997070 | \n",
- " 55.752361 | \n",
- " 0.928668 | \n",
- " 499.765717 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_7/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_8/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 8 | \n",
- " 0.089368 | \n",
- " 7.897106 | \n",
- " 3.679746 | \n",
- " 3.599065 | \n",
- " 11.460873 | \n",
- " 0.988683 | \n",
- " 0.989738 | \n",
- " 157.343246 | \n",
- " 153.839539 | \n",
- " 0.968352 | \n",
- " 31.998373 | \n",
- " 53.455093 | \n",
- " 0.913877 | \n",
- " 732.317871 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_8/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_9/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 9 | \n",
- " 0.096923 | \n",
- " 5.396312 | \n",
- " 3.695744 | \n",
- " 3.599065 | \n",
- " 8.970472 | \n",
- " 0.982039 | \n",
- " 0.982001 | \n",
- " 181.313721 | \n",
- " 176.829346 | \n",
- " 0.966578 | \n",
- " 31.997070 | \n",
- " 51.393932 | \n",
- " 0.895696 | \n",
- " 1123.425049 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_9/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_10/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 10 | \n",
- " 0.106645 | \n",
- " 6.193092 | \n",
- " 3.692219 | \n",
- " 3.599065 | \n",
- " 9.754217 | \n",
- " 0.982780 | \n",
- " 0.984866 | \n",
- " 224.287598 | \n",
- " 218.769226 | \n",
- " 0.968813 | \n",
- " 31.998047 | \n",
- " 47.411495 | \n",
- " 0.877616 | \n",
- " 1806.194092 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_10/metrics.json | \n",
- "
\n",
- " \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_11/metrics.json | \n",
- " 5 | \n",
- " 128 | \n",
- " 11 | \n",
- " 0.136668 | \n",
- " 13.087515 | \n",
- " 3.732962 | \n",
- " 3.599065 | \n",
- " 16.484846 | \n",
- " 0.989557 | \n",
- " 0.989609 | \n",
- " 395.539520 | \n",
- " 391.472504 | \n",
- " 0.989259 | \n",
- " 32.000000 | \n",
- " 31.824055 | \n",
- " 0.870424 | \n",
- " 3098.075928 | \n",
- " 6144.000000 | \n",
- " OAI_GPT2Small_v5_128k_resid_post_attn/v5_128k_layer_11/metrics.json | \n",
- "
\n",
- " \n",
- "
\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "import pandas as pd\n",
- "\n",
- "# get all json files in all subfolders of the mother path\n",
- "import os\n",
- "import json\n",
- "from IPython.display import display\n",
- "import imgkit\n",
- "\n",
- "\n",
- "def get_all_json_files(mother_path):\n",
- " json_files = []\n",
- "\n",
- " for root, dirs, files in os.walk(mother_path):\n",
- " for file in files:\n",
- " if file.endswith(\"metrics.json\"):\n",
- " json_files.append(os.path.join(root, file))\n",
- " return json_files\n",
- "\n",
- "\n",
- "def get_benchmark_stats_csv(\n",
- " mother_path=\"open_ai_sae_weights_resid_post_attn_reformatted\",\n",
- "):\n",
- " json_files = get_all_json_files(mother_path)\n",
- " eval_metrics = {}\n",
- "\n",
- " for file in json_files:\n",
- " with open(file, \"r\") as f:\n",
- " data = json.load(f)\n",
- " eval_metrics[file] = data\n",
- "\n",
- " df = pd.DataFrame(eval_metrics).T\n",
- " df[\"filepath\"] = df.index\n",
- " df.head()\n",
- " pattern = r\".*/v(\\d+)_(\\d+)k_layer_(\\d+)/metrics\\.json\"\n",
- "\n",
- " df[[\"version\", \"d_sae\", \"layer\"]] = df.filepath.str.extract(pattern)\n",
- " # move these columns to the start\n",
- " cols = df.columns.tolist()\n",
- " cols = cols[-3:] + cols[:-3]\n",
- " df = df[cols]\n",
- " df[\"layer\"] = df[\"layer\"].astype(int)\n",
- "\n",
- " # remove \"metrics\" prefix from the columns\n",
- " df.columns = [i.replace(\"metrics/\", \"\") for i in df.columns]\n",
- " df.sort_values(by=[\"version\", \"d_sae\", \"layer\"], inplace=True)\n",
- " df.to_csv(os.path.join(mother_path, \"benchmark_stats.csv\"))\n",
- " df.style.background_gradient(cmap=\"viridis\", axis=0).to_html(\n",
- " os.path.join(mother_path, \"benchmark_stats.html\")\n",
- " )\n",
- "\n",
- " # read the html\n",
- " with open(os.path.join(mother_path, \"benchmark_stats.html\"), \"r\") as f:\n",
- " html = f.read()\n",
- " imgkit.from_string(html, os.path.join(mother_path, \"benchmark_stats.png\"))\n",
- "\n",
- " return df.style.background_gradient(cmap=\"viridis\", axis=0)\n",
- "\n",
- "\n",
- "# list all paths that start with OAI in the current fold\n",
- "paths = [i for i in os.listdir(\".\") if i.startswith(\"OAI\")]\n",
- "tables = []\n",
- "for path in paths:\n",
- " print(path)\n",
- " display(get_benchmark_stats_csv(path))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": ".venv",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.7"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/content/dashboard_screenshot.png b/content/dashboard_screenshot.png
deleted file mode 100644
index 732217e5f..000000000
Binary files a/content/dashboard_screenshot.png and /dev/null differ
diff --git a/content/readme_screenshot_predict_pronoun_feature.png b/content/readme_screenshot_predict_pronoun_feature.png
deleted file mode 100644
index 904b1277d..000000000
Binary files a/content/readme_screenshot_predict_pronoun_feature.png and /dev/null differ
diff --git a/docs/feature_dashboards.md b/docs/feature_dashboards.md
deleted file mode 100644
index 05652a9d8..000000000
--- a/docs/feature_dashboards.md
+++ /dev/null
@@ -1,8 +0,0 @@
-
-## Example Output
-
-Here's one feature we found in the residual stream of Layer 10 of GPT-2 Small:
-
-. Open `gpt2_resid_pre10_predict_pronoun_feature.html` in your browser to interact with the dashboard (WIP).
-
-Note, probably this feature could split into more mono-semantic features in a larger SAE that had been trained for longer. (this was was only about 49152 features trained on 10M tokens from OpenWebText).
diff --git a/docs/reference.md b/docs/reference.md
deleted file mode 100644
index e69de29bb..000000000
diff --git a/eval_metrics_resid_mid_oai.csv b/eval_metrics_resid_mid_oai.csv
deleted file mode 100644
index 1221f4c4a..000000000
--- a/eval_metrics_resid_mid_oai.csv
+++ /dev/null
@@ -1,25 +0,0 @@
-,version,d_sae,layer,metrics/kl_div_with_sae,metrics/kl_div_with_ablation,metrics/ce_loss_with_sae,metrics/ce_loss_without_sae,metrics/ce_loss_with_ablation,metrics/kl_div_score,metrics/ce_loss_score,metrics/l2_norm_in,metrics/l2_norm_out,metrics/l2_ratio,metrics/l0,metrics/l1,metrics/explained_variance,metrics/mse,metrics/total_tokens_evaluated,filepath
-open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_0/metrics.json,5,128,0,0.0038433405570685863,12.480283737182617,3.603421926498413,3.599064588546753,15.861976623535156,0.9996920470208848,0.9996446734723997,32.70796203613281,32.60739517211914,0.9969363212585449,31.98079490661621,44.247344970703125,0.9764951467514038,5.842685222625732,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_0/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_1/metrics.json,5,128,1,0.006731455214321613,16.217103958129883,3.605462074279785,3.599064588546753,19.600265502929688,0.999584916318493,0.999600187150498,56.929866790771484,56.6934928894043,0.9958688616752625,31.9990234375,59.95335388183594,0.9722690582275391,31.145605087280273,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_1/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_2/metrics.json,5,128,2,0.014034643769264221,12.81351089477539,3.6128032207489014,3.599064588546753,16.32787322998047,0.9989046995874498,0.9989206662941394,68.90753173828125,68.5189208984375,0.9936020374298096,31.99934959411621,53.266326904296875,0.9691864848136902,49.732627868652344,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_2/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_3/metrics.json,5,128,3,0.023511391133069992,10.10186767578125,3.622774839401245,3.599064588546753,13.548822402954102,0.9976725698764163,0.9976170022128419,103.71144104003906,103.02696228027344,0.9895378351211548,31.983074188232422,54.82560348510742,0.9752952456474304,89.62007904052734,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_3/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_4/metrics.json,5,128,4,0.03753046691417694,13.249712944030762,3.640705108642578,3.599064588546753,16.69910430908203,0.9971674505649509,0.9968213439818392,111.40328216552734,110.28677368164062,0.9840068817138672,31.94856834411621,56.49782943725586,0.9639231562614441,153.8262939453125,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_4/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_5/metrics.json,5,128,5,0.04741125553846359,11.519681930541992,3.6441657543182373,3.599064588546753,14.860109329223633,0.9958843259888311,0.9959949394740818,119.6514892578125,118.05345916748047,0.9789802432060242,31.9580078125,56.149471282958984,0.9522756338119507,238.1262664794922,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_5/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_6/metrics.json,5,128,6,0.06079387292265892,6.933250427246094,3.655242681503296,3.599064588546753,10.522689819335938,0.9912315480941302,0.9918860291994546,128.84793090820312,126.68817138671875,0.9742312431335449,31.99609375,56.119468688964844,0.940334677696228,349.56976318359375,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_6/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_7/metrics.json,5,128,7,0.07359001040458679,9.511523246765137,3.668210983276367,3.599064588546753,13.054040908813477,0.9922630678078178,0.9926867722998524,140.90599060058594,138.12875366210938,0.9708306789398193,31.9970703125,55.75236129760742,0.928668200969696,499.7657165527344,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_7/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_8/metrics.json,5,128,8,0.08936788886785507,7.897105693817139,3.679746389389038,3.599064588546753,11.460872650146484,0.9886834629884942,0.9897375005583807,157.34324645996094,153.83953857421875,0.9683517217636108,31.99837303161621,53.45509338378906,0.9138767123222351,732.31787109375,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_8/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_9/metrics.json,5,128,9,0.09692329168319702,5.3963117599487305,3.6957435607910156,3.599064588546753,8.97047233581543,0.9820389747674404,0.9820011853887981,181.313720703125,176.829345703125,0.966578483581543,31.9970703125,51.3939323425293,0.8956956267356873,1123.425048828125,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_9/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_10/metrics.json,5,128,10,0.10664453357458115,6.193092346191406,3.692218780517578,3.599064588546753,9.754217147827148,0.98278008341985,0.9848656566878472,224.28759765625,218.76922607421875,0.9688126444816589,31.998046875,47.411495208740234,0.8776161670684814,1806.194091796875,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_10/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_11/metrics.json,5,128,11,0.1366678923368454,13.087514877319336,3.7329623699188232,3.599064588546753,16.484846115112305,0.9895573839939856,0.9896088738509167,395.5395202636719,391.4725036621094,0.9892591834068298,32.0,31.824054718017578,0.8704243898391724,3098.075927734375,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_128k_layer_11/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_0/metrics.json,5,32,0,0.012114070355892181,12.480283737182617,3.614269256591797,3.599064588546753,15.861976623535156,0.9990293433538053,0.9987601095072963,32.70796203613281,32.566856384277344,0.9957150220870972,31.9892578125,44.14577865600586,0.967971682548523,8.511816024780273,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_0/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_1/metrics.json,5,32,1,0.0107693150639534,16.217103958129883,3.608069896697998,3.599064588546753,19.600265502929688,0.9993359285917043,0.9994372104819239,56.929866790771484,56.55320358276367,0.9934985041618347,32.0,50.73607635498047,0.9608334302902222,47.57225799560547,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_1/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_2/metrics.json,5,32,2,0.02015852928161621,12.81351089477539,3.6234664916992188,3.599064588546753,16.32787322998047,0.9984267754991463,0.9980829389584006,68.90753173828125,68.3624496459961,0.9911429286003113,31.99934959411621,52.513954162597656,0.9571983218193054,71.10215759277344,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_2/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_3/metrics.json,5,32,3,0.03577210009098053,10.10186767578125,3.63696551322937,3.599064588546753,13.548822402954102,0.9964588627332011,0.9961907691232709,103.71144104003906,102.80717468261719,0.9862173795700073,31.92578125,53.469627380371094,0.9659597277641296,125.999267578125,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_3/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_4/metrics.json,5,32,4,0.05706218630075455,13.249712944030762,3.668520927429199,3.599064588546753,16.69910430908203,0.9956933265994671,0.9946980054744744,111.40328216552734,109.93180847167969,0.9791091680526733,31.93115234375,55.739112854003906,0.9510807991027832,211.80575561523438,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_4/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_5/metrics.json,5,32,5,0.06962442398071289,11.519681930541992,3.673067569732666,3.599064588546753,14.860109329223633,0.9939560463213729,0.9934284089185259,119.6514892578125,117.60850524902344,0.9731881022453308,31.87776756286621,56.81734848022461,0.9374126195907593,317.22320556640625,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_5/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_6/metrics.json,5,32,6,0.08721713721752167,6.933250427246094,3.691099166870117,3.599064588546753,10.522689819335938,0.9874204547877315,0.9867071692566362,128.84793090820312,126.08697509765625,0.9672197699546814,31.995281219482422,58.89768600463867,0.923336923122406,455.1238708496094,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_6/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_7/metrics.json,5,32,7,0.10268213599920273,9.511523246765137,3.708052635192871,3.599064588546753,13.054040908813477,0.9892044488211575,0.9884729434580918,140.90599060058594,137.48452758789062,0.9640159010887146,31.88655662536621,56.6595458984375,0.910139262676239,636.6057739257812,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_7/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_8/metrics.json,5,32,8,0.12361977994441986,7.897105693817139,3.7338976860046387,3.599064588546753,11.460872650146484,0.9843461915368303,0.9828496070622145,157.34324645996094,153.03326416015625,0.9611150622367859,31.998699188232422,55.846397399902344,0.8925139307975769,920.3406982421875,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_8/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_9/metrics.json,5,32,9,0.13671526312828064,5.3963117599487305,3.7445876598358154,3.599064588546753,8.97047233581543,0.974665054724418,0.9729078338238127,181.313720703125,175.99746704101562,0.9603419303894043,31.99365234375,54.271484375,0.8718730807304382,1386.9677734375,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_9/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_10/metrics.json,5,32,10,0.15013137459754944,6.193092346191406,3.750826358795166,3.599064588546753,9.754217147827148,0.9757582535177477,0.9753439465899841,224.28759765625,217.7724609375,0.962986946105957,32.0,49.59239959716797,0.851866602897644,2193.0966796875,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_10/metrics.json
-open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_11/metrics.json,5,32,11,0.17566144466400146,13.087514877319336,3.781702756881714,3.599064588546753,16.484846115112305,0.9865779373463466,0.9858263801882384,395.5395202636719,390.68212890625,0.9870077967643738,31.99007225036621,32.074562072753906,0.846983790397644,3677.22607421875,6144.0,open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_11/metrics.json
diff --git a/make_hf_repo.sh b/make_hf_repo.sh
deleted file mode 100755
index 9f681e5d4..000000000
--- a/make_hf_repo.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-#!/bin/bash
-
-# # Set these variables
-# LOCAL_FOLDER="OAI_GPT2Small_v5_32k_resid_delta_attn"
-# REPO_NAME="GPT2-Small-OAI-v5-32k-attn-out-SAEs"
-# USERNAME="jbloom"
-
-
-# It's actually really easy to upload folders
-huggingface-cli repo create GPT2-Small-OAI-v5-128k-attn-out-SAEs
-cd OAI_GPT2Small_v5_128k_resid_delta_attn
-huggingface-cli upload jbloom/GPT2-Small-OAI-v5-128k-attn-out-SAEs .
-
-huggingface-cli repo create GPT2-Small-OAI-v5-128k-resid-mid-SAEs
-cd OAI_GPT2Small_v5_128k_resid_post_attn
-huggingface-cli upload jbloom/GPT2-Small-OAI-v5-128k-resid-mid-SAEs .
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 7b2fba712..db222d6f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sae-lens"
-version = "5.10.2"
+version = "5.10.3"
description = "Training and Analyzing Sparse Autoencoders (SAEs)"
authors = ["Joseph Bloom"]
readme = "README.md"
diff --git a/sae_lens/__init__.py b/sae_lens/__init__.py
index 899252ff1..96e18d6d8 100644
--- a/sae_lens/__init__.py
+++ b/sae_lens/__init__.py
@@ -1,5 +1,5 @@
# ruff: noqa: E402
-__version__ = "5.10.2"
+__version__ = "5.10.3"
import logging
diff --git a/sae_lens/config.py b/sae_lens/config.py
index 6faf8ac9b..c79a9c308 100644
--- a/sae_lens/config.py
+++ b/sae_lens/config.py
@@ -99,6 +99,8 @@ class LanguageModelSAERunnerConfig:
prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
jumprelu_init_threshold (float): The threshold to initialize for training JumpReLU SAEs.
jumprelu_bandwidth (float): Bandwidth for training JumpReLU SAEs.
+ fisher_lambda_term (float): The lambda term for the Fisher information matrix. Used for redundant feature removal.
+ use_fisher (bool): Whether to use the Fisher information matrix for redundant feature removal.
autocast (bool): Whether to use autocast during training. Saves vram.
autocast_lm (bool): Whether to use autocast during activation fetching.
compile_llm (bool): Whether to compile the LLM.
@@ -198,6 +200,10 @@ class LanguageModelSAERunnerConfig:
jumprelu_init_threshold: float = 0.001
jumprelu_bandwidth: float = 0.001
+ # Likelihood Context
+ fisher_lambda_term: float = 1e-4
+ use_fisher: bool = False # whether to use the Fisher information matrix for redudant feature removal
+
# Performance - see compilation section of lm_runner.py for info
autocast: bool = False # autocast to autocast_dtype during training
autocast_lm: bool = False # autocast lm during activation fetching
@@ -477,6 +483,8 @@ def get_training_sae_cfg_dict(self) -> dict[str, Any]:
"jumprelu_init_threshold": self.jumprelu_init_threshold,
"jumprelu_bandwidth": self.jumprelu_bandwidth,
"scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
+ "fisher_lambda_term": self.fisher_lambda_term,
+ "use_fisher": self.use_fisher,
}
def to_dict(self) -> dict[str, Any]:
diff --git a/sae_lens/evals.py b/sae_lens/evals.py
index 63d43d203..f41b7bbc4 100644
--- a/sae_lens/evals.py
+++ b/sae_lens/evals.py
@@ -712,12 +712,15 @@ def single_head_zero_ablate_hook(activations: torch.Tensor, hook: Any): # noqa:
)
def kl(original_logits: torch.Tensor, new_logits: torch.Tensor):
- original_probs = torch.nn.functional.softmax(original_logits, dim=-1)
- log_original_probs = torch.log(original_probs)
- new_probs = torch.nn.functional.softmax(new_logits, dim=-1)
- log_new_probs = torch.log(new_probs)
- kl_div = original_probs * (log_original_probs - log_new_probs)
- return kl_div.sum(dim=-1)
+ # Computes the log-probabilities of the new logits (approximation).
+ log_probs_new = torch.nn.functional.log_softmax(new_logits, dim=-1)
+ # Computes the probabilities of the original logits (true distribution).
+ probs_orig = torch.nn.functional.softmax(original_logits, dim=-1)
+ # Compute the KL divergence. torch.nn.functional.kl_div expects the first argument to be the log
+ # probabilities of the approximation (new), and the second argument to be the true distribution
+ # (original) as probabilities. This computes KL(original || new).
+ kl = torch.nn.functional.kl_div(log_probs_new, probs_orig, reduction="none")
+ return kl.sum(dim=-1)
if compute_kl:
recons_kl_div = kl(original_logits, recons_logits)
diff --git a/sae_lens/sae.py b/sae_lens/sae.py
index edd873b30..881bc4a4c 100644
--- a/sae_lens/sae.py
+++ b/sae_lens/sae.py
@@ -44,7 +44,7 @@
@dataclass
class SAEConfig:
# architecture details
- architecture: Literal["standard", "gated", "jumprelu", "topk"]
+ architecture: Literal["standard", "gated", "jumprelu", "topk", "singular_fisher"]
# forward pass details.
d_in: int
@@ -172,6 +172,9 @@ def __init__(
elif self.cfg.architecture == "jumprelu":
self.initialize_weights_jumprelu()
self.encode = self.encode_jumprelu
+ elif self.cfg.architecture == "singular_fisher":
+ self.initialize_weights_singular_fisher()
+ self.encode = self.encode_singular_fisher
else:
raise ValueError(f"Invalid architecture: {self.cfg.architecture}")
@@ -321,6 +324,14 @@ def initialize_weights_jumprelu(self):
)
self.initialize_weights_basic()
+ def initialize_weights_singular_fisher(self):
+ # The params are identical to the standard SAE
+ # except we use a threshold parameter too
+ self.threshold = nn.Parameter(
+ torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+ )
+ self.initialize_weights_basic()
+
@overload
def to(
self: T,
@@ -441,6 +452,22 @@ def encode_standard(
hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+ def encode_singular_fisher(
+ self, x: Float[torch.Tensor, "... d_in"]
+ ) -> Float[torch.Tensor, "... d_sae"]:
+ """
+ Calculate SAE features from inputs
+ """
+ sae_in = self.process_sae_in(x)
+
+ # "... d_in, d_in d_sae -> ... d_sae",
+ hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+ return self.hook_sae_acts_post(
+ self.activation_fn(hidden_pre) * (hidden_pre > self.threshold)
+ )
+
def process_sae_in(
self, sae_in: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
@@ -475,7 +502,7 @@ def fold_W_dec_norm(self):
self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
- elif self.cfg.architecture == "jumprelu":
+ elif self.cfg.architecture == "jumprelu" or self.cfg.architecture == "singular_fisher":
self.threshold.data = self.threshold.data * W_dec_norms.squeeze()
self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
else:
diff --git a/sae_lens/toolkit/pretrained_sae_loaders.py b/sae_lens/toolkit/pretrained_sae_loaders.py
index 583f0528a..ffb08f353 100644
--- a/sae_lens/toolkit/pretrained_sae_loaders.py
+++ b/sae_lens/toolkit/pretrained_sae_loaders.py
@@ -473,11 +473,20 @@ def get_llama_scope_config_from_hf(
# Model specific parameters
model_name, d_in = "meta-llama/Llama-3.1-8B", old_cfg_dict["d_model"]
+ # Get norm scaling factor to rescale jumprelu threshold.
+ # We need this because sae.fold_activation_norm_scaling_factor folds scaling norm into W_enc.
+ # This requires jumprelu threshold to be scaled in the same way
+ norm_scaling_factor = (
+ d_in**0.5 / old_cfg_dict["dataset_average_activation_norm"]["in"]
+ )
+
cfg_dict = {
"architecture": "jumprelu",
- "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"],
+ "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"]
+ * norm_scaling_factor,
# We use a scalar jump_relu_threshold for all features
# This is different from Gemma Scope JumpReLU SAEs.
+ # Scaled with norm_scaling_factor to match sae.fold_activation_norm_scaling_factor
"d_in": d_in,
"d_sae": old_cfg_dict["d_sae"],
"dtype": "bfloat16",
diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py
index eef087e1a..a14a47f23 100644
--- a/sae_lens/training/sae_trainer.py
+++ b/sae_lens/training/sae_trainer.py
@@ -193,7 +193,7 @@ def fit(self) -> TrainingSAE:
self._checkpoint_if_needed()
self.n_training_steps += 1
self._update_pbar(step_output, pbar)
-
+
### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already)
self._begin_finetuning_if_needed()
diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py
index c82ad20aa..7136524e3 100644
--- a/sae_lens/training/training_sae.py
+++ b/sae_lens/training/training_sae.py
@@ -121,6 +121,8 @@ class TrainingSAEConfig(SAEConfig):
decoder_heuristic_init_norm: float
init_encoder_as_decoder_transpose: bool
scale_sparsity_penalty_by_decoder_norm: bool
+ fisher_lambda_term: float
+ use_fisher: bool = False
@classmethod
def from_sae_runner_config(
@@ -163,6 +165,8 @@ def from_sae_runner_config(
model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
jumprelu_init_threshold=cfg.jumprelu_init_threshold,
jumprelu_bandwidth=cfg.jumprelu_bandwidth,
+ fisher_lambda_term=cfg.fisher_lambda_term,
+ use_fisher=cfg.use_fisher,
)
@classmethod
@@ -204,6 +208,8 @@ def to_dict(self) -> dict[str, Any]:
"normalize_activations": self.normalize_activations,
"jumprelu_init_threshold": self.jumprelu_init_threshold,
"jumprelu_bandwidth": self.jumprelu_bandwidth,
+ "fisher_lambda_term": self.fisher_lambda_term,
+ "use_fisher": self.use_fisher,
}
# this needs to exist so we can initialize the parent sae cfg without the training specific
@@ -258,6 +264,12 @@ def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
self.log_threshold.data = torch.ones(
self.cfg.d_sae, dtype=self.dtype, device=self.device
) * np.log(cfg.jumprelu_init_threshold)
+ # elif cfg.architecture == "singular_fisher":
+ # self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_singular_fisher
+ # self.bandwidth = cfg.jumprelu_bandwidth
+ # self.log_threshold.data = torch.ones(
+ # self.cfg.d_sae, dtype=self.dtype, device=self.device
+ # ) * np.log(cfg.jumprelu_init_threshold)
else:
raise ValueError(f"Unknown architecture: {cfg.architecture}")
@@ -281,9 +293,16 @@ def initialize_weights_jumprelu(self):
)
self.initialize_weights_basic()
+ def initialize_weights_singular_fisher(self):
+ # same as the superclass, except we use a log_threshold parameter instead of threshold
+ self.log_threshold = nn.Parameter(
+ torch.empty(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+ )
+ self.initialize_weights_basic()
+
@property
def threshold(self) -> torch.Tensor:
- if self.cfg.architecture != "jumprelu":
+ if self.cfg.architecture != "jumprelu" and self.cfg.architecture != "singular_fisher":
raise ValueError("Threshold is only defined for Jumprelu SAEs")
return torch.exp(self.log_threshold)
@@ -324,6 +343,26 @@ def encode_with_hidden_pre_jumprelu(
return feature_acts, hidden_pre # type: ignore
+ def encode_with_hidden_pre_singular_fisher(
+ self, x: Float[torch.Tensor, "... d_in"]
+ ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+ sae_in : Float[torch.Tensor, "... d_in"] = self.process_sae_in(x)
+ hidden_pre : Float[torch.Tensor, "... d_sae"] = sae_in @ self.W_enc + self.b_enc
+
+ if self.training:
+ hidden_pre = (hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale)
+
+ # Calculate the Fisher regularization term
+ self.fisher_reg_term = self.get_fisher_reg_term(hidden_pre)
+
+
+ threshold = torch.exp(self.log_threshold)
+ feature_acts = JumpReLU.apply(hidden_pre, threshold, self.bandwidth)
+
+ return feature_acts, hidden_pre
+
+
+
def encode_with_hidden_pre(
self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
@@ -357,10 +396,14 @@ def encode_with_hidden_pre_gated(
) # magnitude_pre_activation_noised)
# Return both the gated feature activations and the magnitude pre-activations
- return (
- active_features * feature_magnitudes,
- magnitude_pre_activation,
- ) # magnitude_pre_activation_noised
+ # return (
+ # active_features * feature_magnitudes,
+ # magnitude_pre_activation,
+ # ) # magnitude_pre_activation_noised
+ out_feature_acts = active_features * feature_magnitudes
+ return out_feature_acts, magnitude_pre_activation
+
+
def forward(
self,
@@ -369,6 +412,27 @@ def forward(
feature_acts, _ = self.encode_with_hidden_pre_fn(x)
return self.decode(feature_acts)
+ def get_fisher_reg_term(self, feature_acts) -> torch.Tensor:
+ pseudo_loss = torch.sum(feature_acts**2)
+ grads = torch.autograd.grad(
+ pseudo_loss, feature_acts, create_graph=True, retain_graph=True
+ )[0]
+
+ # Normalize feature activations along the batch (rows)
+ grads = grads - grads.mean(dim=0, keepdim=True)
+ grads = grads / (grads.norm(dim=0, keepdim=True) + 1e-8)
+
+ # Compute empirical Fisher matrix
+ F_empirical = torch.einsum("bi,bj->ij", grads, grads) / feature_acts.shape[0]
+
+ # Remove diagonal (self-correlation)
+ off_diag = F_empirical - torch.diag(torch.diag(F_empirical))
+
+ # Penalize off-diagonal (correlation) terms
+ fisher_reg_term = self.cfg.fisher_lambda_term * torch.sum(off_diag**2)
+
+ return fisher_reg_term
+
def training_forward_pass(
self,
sae_in: torch.Tensor,
@@ -378,6 +442,8 @@ def training_forward_pass(
# do a forward pass to get SAE out, but we also need the
# hidden pre.
feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
+ if self.cfg.use_fisher and self.cfg.architecture != "gated":
+ self.fisher_reg_term = self.get_fisher_reg_term(feature_acts)
sae_out = self.decode(feature_acts)
# MSE LOSS
@@ -407,9 +473,16 @@ def training_forward_pass(
aux_reconstruction_loss = torch.sum(
(via_gate_reconstruction - sae_in) ** 2, dim=-1
).mean()
+
loss = mse_loss + l1_loss + aux_reconstruction_loss
+
losses["auxiliary_reconstruction_loss"] = aux_reconstruction_loss
losses["l1_loss"] = l1_loss
+ if self.cfg.use_fisher:
+ self.fisher_reg_term = self.get_fisher_reg_term(feature_acts)
+ losses["fisher_reg_term"] = self.fisher_reg_term
+ loss += self.fisher_reg_term
+
elif self.cfg.architecture == "jumprelu":
threshold = torch.exp(self.log_threshold)
l0 = torch.sum(Step.apply(hidden_pre, threshold, self.bandwidth), dim=-1) # type: ignore
@@ -425,6 +498,13 @@ def training_forward_pass(
)
losses["auxiliary_reconstruction_loss"] = topk_loss
loss = mse_loss + topk_loss
+ elif self.cfg.architecture == "singular_fisher":
+ threshold = torch.exp(self.log_threshold)
+ l0 = torch.sum(Step.apply(hidden_pre, threshold, self.bandwidth), dim=-1) # type: ignore
+ l0_loss = (current_l1_coefficient * l0).mean()
+ loss = mse_loss + l0_loss + self.fisher_reg_term
+ losses["l0_loss"] = l0_loss
+ losses["fisher_reg_term"] = self.fisher_reg_term
else:
# default SAE sparsity loss
weighted_feature_acts = feature_acts
@@ -453,6 +533,9 @@ def training_forward_pass(
losses["l1_loss"] = l1_loss
losses["mse_loss"] = mse_loss
+ if self.cfg.use_fisher and self.cfg.architecture != "gated":
+ losses["fisher_reg_term"] = self.fisher_reg_term
+ loss = loss + self.fisher_reg_term
return TrainStepOutput(
sae_in=sae_in,
@@ -553,13 +636,13 @@ def batch_norm_mse_loss_fn(
return standard_mse_loss_fn
def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
- if self.cfg.architecture == "jumprelu" and "log_threshold" in state_dict:
+ if self.cfg.architecture in ["jumprelu", "singular_fisher"] and "log_threshold" in state_dict:
threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
del state_dict["log_threshold"]
state_dict["threshold"] = threshold
def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
- if self.cfg.architecture == "jumprelu" and "threshold" in state_dict:
+ if self.cfg.architecture in ["jumprelu", "singular_fisher"] and "threshold" in state_dict:
threshold = state_dict["threshold"]
del state_dict["threshold"]
state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
@@ -627,7 +710,7 @@ def initialize_weights_complex(self):
@torch.no_grad()
def fold_W_dec_norm(self):
# need to deal with the jumprelu having a log_threshold in training
- if self.cfg.architecture == "jumprelu":
+ if self.cfg.architecture == "jumprelu" or self.cfg.architecture == "singular_fisher":
cur_threshold = self.threshold.clone()
W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
super().fold_W_dec_norm()