Update modeling_phi.py
Browse files
- modeling_phi.py +3 -7
modeling_phi.py
CHANGED
|
@@ -294,15 +294,11 @@ class MoE(nn.Module):
|
|
| 294 |
def __init__(
|
| 295 |
self,
|
| 296 |
config: PretrainedConfig,
|
| 297 |
-
num_experts=2,
|
| 298 |
-
num_experts_per_tok=2,
|
| 299 |
-
num_shards=1,
|
| 300 |
-
**kwargs,
|
| 301 |
):
|
| 302 |
super().__init__()
|
| 303 |
-
self.mlp = nn.ModuleList([MLP(config) for i in range(num_experts)])
|
| 304 |
-
self.gate = nn.Linear(config.n_embd, num_experts, bias=False)
|
| 305 |
-
self.num_experts_per_tok = num_experts_per_tok
|
| 306 |
|
| 307 |
def forward(self, x):
|
| 308 |
orig_shape = x.shape
|
|
|
|
| 294 |
def __init__(
|
| 295 |
self,
|
| 296 |
config: PretrainedConfig,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
):
|
| 298 |
super().__init__()
|
| 299 |
+
self.mlp = nn.ModuleList([MLP(config) for i in range(config.num_local_experts)])
|
| 300 |
+
self.gate = nn.Linear(config.n_embd, config.num_local_experts, bias=False)
|
| 301 |
+
self.num_experts_per_tok = config.num_experts_per_tok
|
| 302 |
|
| 303 |
def forward(self, x):
|
| 304 |
orig_shape = x.shape
|