anonymous-author-129 commited on
Commit
3bb4561
·
verified ·
1 Parent(s): ea0fe51

Upload sae_utils.py

Browse files
Files changed (1) hide show
  1. SAE/sae_utils.py +3 -1
SAE/sae_utils.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  from dataclasses import dataclass, field
 
3
 
4
  @dataclass
5
  class SAETrainingConfig:
@@ -23,6 +24,7 @@ class SAETrainingConfig:
23
  def save_path(self):
24
  return os.path.join(save_path_base, f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}')
25
 
 
26
  @dataclass
27
  class Config:
28
  saes: list[SAETrainingConfig]
@@ -43,4 +45,4 @@ class Config:
43
  self.log_interval = cfg_json['log_interval']
44
  self.save_interval = cfg_json['save_interval']
45
  self.bs = cfg_json['bs']
46
- self.block_name = cfg_json['block_name']
 
1
  import torch
2
  from dataclasses import dataclass, field
3
+ import os
4
 
5
  @dataclass
6
  class SAETrainingConfig:
 
24
  def save_path(self):
25
  return os.path.join(save_path_base, f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}')
26
 
27
+
28
  @dataclass
29
  class Config:
30
  saes: list[SAETrainingConfig]
 
45
  self.log_interval = cfg_json['log_interval']
46
  self.save_interval = cfg_json['save_interval']
47
  self.bs = cfg_json['bs']
48
+ self.block_name = cfg_json['block_name']