Skip to content

Commit 0c91973

Browse files
feat: deprecates eval_hook (#511)
1 parent c361af6 commit 0c91973

File tree

4 files changed

+21
-4
lines changed

4 files changed

+21
-4
lines changed

sae_lens/config.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import math
3+
import warnings
34
from dataclasses import asdict, dataclass, field
45
from pathlib import Path
56
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
@@ -125,7 +126,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
125126
model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
126127
model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
127128
hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
128-
hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
129+
hook_eval (str): DEPRECATED: Will be removed in v7.0.0. NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
129130
hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
130131
dataset_path (str): A Hugging Face dataset path.
131132
dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
@@ -264,6 +265,14 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
264265
exclude_special_tokens: bool | list[int] = False
265266

266267
def __post_init__(self):
268+
if self.hook_eval != "NOT_IN_USE":
269+
warnings.warn(
270+
"The 'hook_eval' field is deprecated and will be removed in v7.0.0. "
271+
"It is not currently used and can be safely removed from your config.",
272+
DeprecationWarning,
273+
stacklevel=2,
274+
)
275+
267276
if self.use_cached_activations and self.cached_activations_path is None:
268277
self.cached_activations_path = _default_cached_activations_path(
269278
self.dataset_path,

tests/helpers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False):
2727
model_name: str
2828
model_class_name: str
2929
hook_name: str
30-
hook_eval: str
3130
hook_head_index: int | None
3231
dataset_path: str
3332
dataset_trust_remote_code: bool
@@ -118,7 +117,6 @@ def _get_default_runner_config() -> LanguageModelSAERunnerConfigDict:
118117
"model_name": TINYSTORIES_MODEL,
119118
"model_class_name": "HookedTransformer",
120119
"hook_name": "blocks.0.hook_mlp_out",
121-
"hook_eval": "NOT_IN_USE",
122120
"hook_head_index": None,
123121
"dataset_path": NEEL_NANDA_C4_10K_DATASET,
124122
"streaming": False,

tests/training/test_config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ def test_sae_training_runner_config_seqpos(
3434
)
3535

3636

37+
def test_LanguageModelSAERunnerConfig_hook_eval_deprecated_usage():
38+
with pytest.warns(
39+
DeprecationWarning,
40+
match="The 'hook_eval' field is deprecated and will be removed in v7.0.0. ",
41+
):
42+
LanguageModelSAERunnerConfig(
43+
sae=StandardTrainingSAEConfig(d_in=10, d_sae=10),
44+
hook_eval="blocks.0.hook_output",
45+
)
46+
47+
3748
@pytest.mark.parametrize("seqpos_slice, expected_error", test_cases_for_seqpos)
3849
def test_cache_activations_runner_config_seqpos(
3950
seqpos_slice: tuple[int, int],

tutorials/mamba_train_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
model_name="state-spaces/mamba-370m",
2424
model_class_name="HookedMamba",
2525
hook_name="blocks.39.hook_ssm_input",
26-
hook_eval="blocks.39.hook_ssm_output", # we compare this when replace hook_point activations with autoencode.decode(autoencoder.encode( hook_point activations))
2726
dataset_path="NeelNanda/openwebtext-tokenized-9b",
2827
is_dataset_tokenized=True,
2928
# SAE Parameters

0 commit comments

Comments
 (0)