Update models/attn_model.py
models/attn_model.py  CHANGED  (+6 -2)
@@ -1,3 +1,4 @@
+import os
 import torch
 from .model import Model
 from .utils import sample_token, get_last_attn
@@ -5,6 +6,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch.nn.functional as F
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
+token = os.getenv("HF_TOKEN")
 
 class AttentionModel(Model):
     def __init__(self, config):
@@ -12,12 +14,14 @@ class AttentionModel(Model):
         self.name = config["model_info"]["name"]
         self.max_output_tokens = int(config["params"]["max_output_tokens"])
         model_id = config["model_info"]["model_id"]
-        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id,
+                                                       use_auth_token=token)
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype=torch.bfloat16,
             device_map=device,
-            attn_implementation="eager"
+            attn_implementation="eager",
+            use_auth_token=token
         ).eval()
         if config["params"]["important_heads"] == "all":
             attn_size = self.get_map_dim()