Upload DogeForCausalLM
Browse files- config.json +1 -2
- configuration_doge.py +1 -5
- model.safetensors +1 -1
- modeling_doge.py +12 -20
config.json
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
{
|
| 2 |
-
"_name_or_path": "/root/autodl-tmp/data/Doge-320M",
|
| 3 |
"architectures": [
|
| 4 |
"DogeForCausalLM"
|
| 5 |
],
|
|
@@ -11,7 +11,6 @@
|
|
| 11 |
"bos_token_id": 0,
|
| 12 |
"dynamic_mask_ratio": 0.0,
|
| 13 |
"eos_token_id": 1,
|
| 14 |
-
"expert_retrieval_size": 64,
|
| 15 |
"hidden_act": "silu",
|
| 16 |
"hidden_bias": false,
|
| 17 |
"hidden_dropout": 0.0,
|
|
|
|
| 1 |
{
|
| 2 |
+
"_name_or_path": "/root/autodl-tmp/small-doge/data/Doge-320M-decay/checkpoint-4000",
|
| 3 |
"architectures": [
|
| 4 |
"DogeForCausalLM"
|
| 5 |
],
|
|
|
|
| 11 |
"bos_token_id": 0,
|
| 12 |
"dynamic_mask_ratio": 0.0,
|
| 13 |
"eos_token_id": 1,
|
|
|
|
| 14 |
"hidden_act": "silu",
|
| 15 |
"hidden_bias": false,
|
| 16 |
"hidden_dropout": 0.0,
|
configuration_doge.py
CHANGED
|
@@ -121,8 +121,6 @@ class DogeConfig(PretrainedConfig):
|
|
| 121 |
Number of Experts for the Cross Domain Mixture of Experts.
|
| 122 |
num_experts_per_tok (`int`, *optional*, defaults to 8):
|
| 123 |
Number of selected experts to route per-token.
|
| 124 |
-
expert_retrieval_size (`int`, *optional*, defaults to 64):
|
| 125 |
-
Dimension of the Expert retrieval states for calculating the dot product of query and key to determine the expert index.
|
| 126 |
|
| 127 |
```python
|
| 128 |
>>> from transformers import DogeConfig, DogeModel
|
|
@@ -149,7 +147,7 @@ class DogeConfig(PretrainedConfig):
|
|
| 149 |
"layers.*.feed_forward.gate_proj": "colwise",
|
| 150 |
"layers.*.feed_forward.up_proj": "colwise",
|
| 151 |
"layers.*.feed_forward.down_proj": "rowwise",
|
| 152 |
-
"layers.*.feed_forward.
|
| 153 |
"layers.*.feed_forward.down_embed": "rowwise",
|
| 154 |
"layers.*.feed_forward.up_embed": "rowwise",
|
| 155 |
}
|
|
@@ -181,7 +179,6 @@ class DogeConfig(PretrainedConfig):
|
|
| 181 |
is_moe=False,
|
| 182 |
num_experts=2048,
|
| 183 |
num_experts_per_tok=8,
|
| 184 |
-
expert_retrieval_size=64,
|
| 185 |
**kwargs,
|
| 186 |
):
|
| 187 |
self.vocab_size = vocab_size
|
|
@@ -207,7 +204,6 @@ class DogeConfig(PretrainedConfig):
|
|
| 207 |
self.is_moe = is_moe
|
| 208 |
self.num_experts = num_experts
|
| 209 |
self.num_experts_per_tok = num_experts_per_tok
|
| 210 |
-
self.expert_retrieval_size = expert_retrieval_size
|
| 211 |
|
| 212 |
# Validate the correctness of rotary position embeddings parameters
|
| 213 |
# BC: if there is a 'type' field, copy it to 'rope_type'.
|
|
|
|
| 121 |
Number of Experts for the Cross Domain Mixture of Experts.
|
| 122 |
num_experts_per_tok (`int`, *optional*, defaults to 8):
|
| 123 |
Number of selected experts to route per-token.
|
|
|
|
|
|
|
| 124 |
|
| 125 |
```python
|
| 126 |
>>> from transformers import DogeConfig, DogeModel
|
|
|
|
| 147 |
"layers.*.feed_forward.gate_proj": "colwise",
|
| 148 |
"layers.*.feed_forward.up_proj": "colwise",
|
| 149 |
"layers.*.feed_forward.down_proj": "rowwise",
|
| 150 |
+
"layers.*.feed_forward.router_gate": "colwise",
|
| 151 |
"layers.*.feed_forward.down_embed": "rowwise",
|
| 152 |
"layers.*.feed_forward.up_embed": "rowwise",
|
| 153 |
}
|
|
|
|
| 179 |
is_moe=False,
|
| 180 |
num_experts=2048,
|
| 181 |
num_experts_per_tok=8,
|
|
|
|
| 182 |
**kwargs,
|
| 183 |
):
|
| 184 |
self.vocab_size = vocab_size
|
|
|
|
| 204 |
self.is_moe = is_moe
|
| 205 |
self.num_experts = num_experts
|
| 206 |
self.num_experts_per_tok = num_experts_per_tok
|
|
|
|
| 207 |
|
| 208 |
# Validate the correctness of rotary position embeddings parameters
|
| 209 |
# BC: if there is a 'type' field, copy it to 'rope_type'.
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1343277696
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ce4aaf436761b12719bb9be9d3a250ba388679b324886299ed71f69c2b53a510
|
| 3 |
size 1343277696
|
modeling_doge.py
CHANGED
|
@@ -480,23 +480,17 @@ class DogeCDMoE(DogeMLP):
|
|
| 480 |
self.hidden_dim = config.hidden_size
|
| 481 |
self.act_fn = ACT2FN[config.hidden_act]
|
| 482 |
|
| 483 |
-
self.expert_retrieval_dim = config.expert_retrieval_size
|
| 484 |
self.num_experts = config.num_experts
|
| 485 |
self.top_k = config.num_experts_per_tok
|
| 486 |
self.num_keys = int(math.sqrt(self.num_experts))
|
| 487 |
|
| 488 |
-
#
|
| 489 |
-
self.
|
| 490 |
-
self.keys = nn.Parameter(torch.zeros(2, self.expert_retrieval_dim // 2, self.num_keys))
|
| 491 |
|
| 492 |
# experts
|
| 493 |
self.down_embed = nn.Embedding(self.num_experts, self.hidden_dim)
|
| 494 |
self.up_embed = nn.Embedding(self.num_experts, self.hidden_dim)
|
| 495 |
|
| 496 |
-
# scaling factor
|
| 497 |
-
self.mlp_scaling = nn.Parameter(torch.ones(self.hidden_dim))
|
| 498 |
-
self.moe_scaling = nn.Parameter(torch.zeros(self.hidden_dim))
|
| 499 |
-
|
| 500 |
def forward(
|
| 501 |
self,
|
| 502 |
hidden_states: torch.Tensor,
|
|
@@ -504,27 +498,25 @@ class DogeCDMoE(DogeMLP):
|
|
| 504 |
) -> torch.Tensor:
|
| 505 |
bsz, seq_len, _ = hidden_states.shape
|
| 506 |
|
| 507 |
-
# get routing weights with
|
| 508 |
-
|
| 509 |
-
routing_weights = torch.matmul(queries, self.keys)
|
| 510 |
|
| 511 |
# get experts with the highest routing weights
|
| 512 |
-
(scores_x, scores_y), (indices_x, indices_y) =
|
| 513 |
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
all_indices = all_indices.view(*
|
| 517 |
-
scores,
|
| 518 |
-
|
| 519 |
-
down_embed = self.down_embed(indices).transpose(1, 2)
|
| 520 |
up_embed = self.up_embed(indices)
|
| 521 |
|
| 522 |
# mix experts states with cross domain states
|
| 523 |
-
experts_weights = torch.matmul(hidden_states.view(bsz * seq_len, 1,
|
| 524 |
experts_weights = self.act_fn(experts_weights) * scores.softmax(dim=-1)
|
| 525 |
experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1)
|
| 526 |
hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
|
| 527 |
-
hidden_states =
|
| 528 |
return hidden_states
|
| 529 |
|
| 530 |
|
|
|
|
| 480 |
self.hidden_dim = config.hidden_size
|
| 481 |
self.act_fn = ACT2FN[config.hidden_act]
|
| 482 |
|
|
|
|
| 483 |
self.num_experts = config.num_experts
|
| 484 |
self.top_k = config.num_experts_per_tok
|
| 485 |
self.num_keys = int(math.sqrt(self.num_experts))
|
| 486 |
|
| 487 |
+
# router gate for retrieval experts
|
| 488 |
+
self.router_gate = nn.Linear(self.hidden_dim, self.num_keys * 2)
|
|
|
|
| 489 |
|
| 490 |
# experts
|
| 491 |
self.down_embed = nn.Embedding(self.num_experts, self.hidden_dim)
|
| 492 |
self.up_embed = nn.Embedding(self.num_experts, self.hidden_dim)
|
| 493 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
def forward(
|
| 495 |
self,
|
| 496 |
hidden_states: torch.Tensor,
|
|
|
|
| 498 |
) -> torch.Tensor:
|
| 499 |
bsz, seq_len, _ = hidden_states.shape
|
| 500 |
|
| 501 |
+
# get routing weights with router gate
|
| 502 |
+
routing_weights = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)
|
|
|
|
| 503 |
|
| 504 |
# get experts with the highest routing weights
|
| 505 |
+
(scores_x, scores_y), (indices_x, indices_y) = [w.topk(self.num_keys, dim=-1) for w in routing_weights]
|
| 506 |
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
|
| 507 |
+
all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
|
| 508 |
+
all_scores = all_scores.view(*all_scores.shape[:-2], -1)
|
| 509 |
+
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
|
| 510 |
+
scores, indices = all_scores.topk(self.top_k, dim=-1)
|
| 511 |
+
down_embed = self.down_embed(indices)
|
|
|
|
| 512 |
up_embed = self.up_embed(indices)
|
| 513 |
|
| 514 |
# mix experts states with cross domain states
|
| 515 |
+
experts_weights = torch.matmul(down_embed, hidden_states.view(bsz * seq_len, -1, 1)).view(bsz * seq_len, -1)
|
| 516 |
experts_weights = self.act_fn(experts_weights) * scores.softmax(dim=-1)
|
| 517 |
experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1)
|
| 518 |
hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
|
| 519 |
+
hidden_states = hidden_states + experts_states
|
| 520 |
return hidden_states
|
| 521 |
|
| 522 |
|