Spaces:
Sleeping
Sleeping
Upload sae_utils.py
Browse files- SAE/sae_utils.py +3 -1
SAE/sae_utils.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
from dataclasses import dataclass, field
|
|
|
|
| 3 |
|
| 4 |
@dataclass
|
| 5 |
class SAETrainingConfig:
|
|
@@ -23,6 +24,7 @@ class SAETrainingConfig:
|
|
| 23 |
def save_path(self):
|
| 24 |
return os.path.join(save_path_base, f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}')
|
| 25 |
|
|
|
|
| 26 |
@dataclass
|
| 27 |
class Config:
|
| 28 |
saes: list[SAETrainingConfig]
|
|
@@ -43,4 +45,4 @@ class Config:
|
|
| 43 |
self.log_interval = cfg_json['log_interval']
|
| 44 |
self.save_interval = cfg_json['save_interval']
|
| 45 |
self.bs = cfg_json['bs']
|
| 46 |
-
self.block_name = cfg_json['block_name']
|
|
|
|
| 1 |
import torch
|
| 2 |
from dataclasses import dataclass, field
|
| 3 |
+
import os
|
| 4 |
|
| 5 |
@dataclass
|
| 6 |
class SAETrainingConfig:
|
|
|
|
| 24 |
def save_path(self):
|
| 25 |
return os.path.join(save_path_base, f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}')
|
| 26 |
|
| 27 |
+
|
| 28 |
@dataclass
|
| 29 |
class Config:
|
| 30 |
saes: list[SAETrainingConfig]
|
|
|
|
| 45 |
self.log_interval = cfg_json['log_interval']
|
| 46 |
self.save_interval = cfg_json['save_interval']
|
| 47 |
self.bs = cfg_json['bs']
|
| 48 |
+
self.block_name = cfg_json['block_name']
|