Files changed (4)
  1. README.md +2 -2
  2. config.json +6 -0
  3. configuration_phi4flash.py +11 -5
  4. modeling_phi4flash.py +9 -9
README.md CHANGED
@@ -25,7 +25,7 @@ The model belongs to the Phi-4 model family and supports 64K token context lengt
  📚 [Training Codebase](https://github.com/microsoft/ArchScale) <br>
  👩‍🍳 [Phi Cookbook](https://github.com/microsoft/PhiCookBook) <br>
  🏡 [Phi Portal](https://azure.microsoft.com/en-us/products/phi) <br>
- 🚀 vLLM Inference: V0: [PR](https://github.com/vllm-project/vllm/pull/20702) | [Branch](https://github.com/congcongchen123/vllm/tree/congcongchen/phi4-mini-shadow) V1: [PR](https://github.com/vllm-project/vllm/pull/23996) <br>
+ 🚀 [vLLM Inference](https://github.com/vllm-project/vllm/pull/20702) <br>
  🖥️ Try It [Azure](https://ai.azure.com/explore/models/Phi-4-mini-flash-reasoning/version/1/registry/azureml-phi-prod) [Nvidia NIM](https://build.nvidia.com/microsoft/phi-4-mini-flash-reasoning)<br>


@@ -236,4 +236,4 @@ Benchmark datasets
  We evaluate the model with three of the most popular math benchmarks where the strongest reasoning models are competing together. Specifically:
  + Math-500: This benchmark consists of 500 challenging math problems designed to test the model's ability to perform complex mathematical reasoning and problem-solving.
  + AIME 2024/AIME 2025: The American Invitational Mathematics Examination (AIME) is a highly regarded math competition that features a series of difficult problems aimed at assessing advanced mathematical skills and logical reasoning. We evaluate the models on the problems from both 2024 and the year 2025 examinations.
- + GPQA Diamond: The Graduate-Level Google-Proof Q&A (GPQA) Diamond benchmark focuses on evaluating the model's ability to understand and solve a wide range of mathematical questions, including both straightforward calculations and more intricate problem-solving tasks.
+ + GPQA Diamond: The Graduate-Level Google-Proof Q&A (GPQA) Diamond benchmark focuses on evaluating the model's ability to understand and solve a wide range of mathematical questions, including both straightforward calculations and more intricate problem-solving tasks.
config.json CHANGED
@@ -26,6 +26,12 @@
  "num_key_value_heads": 20,
  "resid_pdrop": 0.0,
  "sliding_window": 512,
+ "layer_types": [
+ "full_attention", "sliding_attention", "full_attention", "sliding_attention", "full_attention", "sliding_attention", "full_attention", "sliding_attention",
+ "full_attention", "sliding_attention", "full_attention", "sliding_attention", "full_attention", "sliding_attention", "full_attention", "sliding_attention",
+ "full_attention", "full_attention", "full_attention", "full_attention", "full_attention", "full_attention", "full_attention", "full_attention",
+ "full_attention", "full_attention", "full_attention", "full_attention", "full_attention", "full_attention", "full_attention", "full_attention"
+ ],
  "torch_dtype": "bfloat16",
  "tie_word_embeddings": true,
  "transformers_version": "4.46.1",
configuration_phi4flash.py CHANGED
@@ -112,6 +112,7 @@ class Phi4FlashConfig(PretrainedConfig):
          bos_token_id=1,
          eos_token_id=2,
          sliding_window=2047,
+         layer_types=None,
          mb_per_layer= 2,
          mamba_d_state=16,
          mamba_d_conv=4,
@@ -141,11 +142,16 @@
          self.use_cache = use_cache
          self.rope_theta = rope_theta
          self.mb_per_layer = mb_per_layer
-         self.sliding_window = [
-             sliding_window if layer_idx < num_hidden_layers // 2 and layer_idx % 2 == 1 else None
-             for layer_idx in range(num_hidden_layers)
-         ]
+         self.sliding_window = sliding_window
+         self.layer_types = layer_types

+         if self.layer_types is None:
+             is_sliding = lambda i: i < num_hidden_layers // 2 and i % 2 == 1
+             self.layer_types = [
+                 "sliding_attention" if is_sliding(layer_idx) else "full_attention"
+                 for layer_idx in range(num_hidden_layers)
+             ]
+
          self.mamba_d_state = mamba_d_state
          self.mamba_d_conv = mamba_d_conv
          self.mamba_expand = mamba_expand
@@ -170,4 +176,4 @@
          else:
              layer_block_type = "mamba"
          layer_block_types.append(layer_block_type)
-         return layer_block_types
+         return layer_block_types
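Net effect of the config-class change: the old per-layer `sliding_window` list (an int for sliding layers, `None` otherwise) is split into a scalar `sliding_window` plus a parallel `layer_types` list of strings that marks exactly the same layers as sliding-window layers. A standalone sketch of the two encodings side by side (plain Python; the layer count is taken from the 32-entry list in `config.json`):

```python
num_hidden_layers = 32  # this checkpoint ships a 32-entry layer_types list
sliding_window = 512

# Old default: a list mixing ints and None, indexed by layer.
old_sliding_window = [
    sliding_window if layer_idx < num_hidden_layers // 2 and layer_idx % 2 == 1 else None
    for layer_idx in range(num_hidden_layers)
]

# New default: a scalar window size plus a list of layer-type strings.
is_sliding = lambda i: i < num_hidden_layers // 2 and i % 2 == 1
layer_types = [
    "sliding_attention" if is_sliding(layer_idx) else "full_attention"
    for layer_idx in range(num_hidden_layers)
]

# Both encodings mark the same layers as sliding-window layers.
assert all(
    (old is not None) == (new == "sliding_attention")
    for old, new in zip(old_sliding_window, layer_types)
)
print(layer_types[:4])  # ['full_attention', 'sliding_attention', 'full_attention', 'sliding_attention']
```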
modeling_phi4flash.py CHANGED
@@ -129,7 +129,7 @@ def _get_cache(
          cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache

          if cache_implementation == "sliding_window":
-             max_cache_len = min(self.config.sliding_window[1], max_cache_len)
+             max_cache_len = min(self.config.sliding_window, max_cache_len)

          need_new_cache = (
              not hasattr(self, "_cache")
@@ -243,7 +243,7 @@ class SambaYCache(Cache):
          sliding_cache_shape = (
              self.max_batch_size,
              self.num_key_value_heads,
-             min(config.sliding_window[1], max_cache_len),
+             min(config.sliding_window, max_cache_len),
              self.head_dim,
          )
          conv_cache_shape = (self.max_batch_size, intermediate_size, conv_kernel_size)
@@ -573,7 +573,7 @@ class SambaYFlashAttention2(SambaYAttention):
          key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
          value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-         use_sliding_windows = self.config.sliding_window is not None and self.config.sliding_window[self.layer_idx] is not None
+         use_sliding_windows = self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention"

          if past_key_value is not None:

@@ -710,8 +710,8 @@ class SambaYFlashAttention2(SambaYAttention):
                  softmax_scale=softmax_scale,
                  causal=causal,
                  window_size=(
-                     self.config.sliding_window[self.layer_idx] -1,
-                     self.config.sliding_window[self.layer_idx] -1,
+                     self.config.sliding_window -1,
+                     self.config.sliding_window -1,
                  ),
              )

@@ -735,8 +735,8 @@ class SambaYFlashAttention2(SambaYAttention):
                  softmax_scale=softmax_scale,
                  causal=causal,
                  window_size=(
-                     self.config.sliding_window[self.layer_idx] -1,
-                     self.config.sliding_window[self.layer_idx] -1,
+                     self.config.sliding_window -1,
+                     self.config.sliding_window -1,
                  ),
              )

@@ -1085,9 +1085,9 @@ class SambaYDecoderLayer(nn.Module):
              residual = residual.to(torch.float32)
              self_attn_weights = None
          else:
-             if self.config.sliding_window is not None and self.config.sliding_window[self.layer_idx] is not None and attention_mask is not None: # efficient SDPA and no padding
+             if self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention" and attention_mask is not None: # efficient SDPA and no padding
                  if past_key_value is not None and cache_position[0] > 0: # when decoding
-                     attention_mask = attention_mask[:, -self.config.sliding_window[self.layer_idx]:]
+                     attention_mask = attention_mask[:, -self.config.sliding_window:]
              #hidden_states = self.input_layernorm2(hidden_states.to(dtype=self.input_layernorm2.weight.dtype))
              # Self Attention
              attn_outputs, self_attn_weights, yoco_key_values = self.attn(
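The modeling changes follow mechanically from the config change: every place that used to index `config.sliding_window[layer_idx]` now checks `config.layer_types[layer_idx]` and reads the scalar `config.sliding_window`. A hedged sketch of the per-layer decision and the derived values (the helper names below are illustrative only; the real logic stays inline in `modeling_phi4flash.py`):

```python
from types import SimpleNamespace

def uses_sliding_window(config, layer_idx: int) -> bool:
    # Mirrors the updated checks in SambaYFlashAttention2 and SambaYDecoderLayer.
    return (
        config.sliding_window is not None
        and config.layer_types[layer_idx] == "sliding_attention"
    )

def flash_attn_window(config):
    # Mirrors the window_size tuple passed to flash attention:
    # (left, right) context sizes, excluding the current token.
    return (config.sliding_window - 1, config.sliding_window - 1)

def sliding_cache_len(config, max_cache_len: int) -> int:
    # Mirrors SambaYCache: the sliding KV cache never needs to hold
    # more than sliding_window positions.
    return min(config.sliding_window, max_cache_len)

# Toy config covering just the fields these helpers read.
cfg = SimpleNamespace(sliding_window=512, layer_types=["full_attention", "sliding_attention"])
assert not uses_sliding_window(cfg, 0) and uses_sliding_window(cfg, 1)
assert flash_attn_window(cfg) == (511, 511)
assert sliding_cache_len(cfg, max_cache_len=65536) == 512
```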