import torch


def expand_inputs_for_generation(
    input_ids,
    expand_size=1,
    is_encoder_decoder=False,
    attention_mask=None,
    encoder_outputs=None,
    **model_kwargs,
):
    # Build an index that repeats each batch row `expand_size` times
    # (e.g. [0, 0, 1, 1] for a batch of 2 with expand_size=2) and use it
    # to duplicate every per-sample input along the batch dimension.
    expanded_return_idx = (
        torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
    )
    input_ids = input_ids.index_select(0, expanded_return_idx)

    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

    if attention_mask is not None:
        model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

    # Image inputs are expanded alongside the text inputs so they stay batch-aligned.
    model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(
        0, expanded_return_idx
    )
    model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)

    if is_encoder_decoder:
        if encoder_outputs is None:
            raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
            0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
        )
        model_kwargs["encoder_outputs"] = encoder_outputs

    return input_ids, model_kwargs
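
# A minimal, self-contained sketch of how expand_inputs_for_generation might be
# exercised (e.g. when several sequences are generated per prompt). The tensor
# shapes below are illustrative assumptions, not values taken from this file.
def _demo_expand_inputs_for_generation():
    # Assumed shapes: batch of 2, 3 text tokens, 1 image of 224x224, 1 image slot per position.
    input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
    expanded_ids, expanded_kwargs = expand_inputs_for_generation(
        input_ids,
        expand_size=2,
        attention_mask=torch.ones(2, 3, dtype=torch.long),
        image_attention_mask=torch.ones(2, 3, 1, dtype=torch.long),
        pixel_values=torch.randn(2, 1, 3, 224, 224),
    )
    # Rows are duplicated back to back: [0, 0, 1, 1].
    assert expanded_ids.tolist() == [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]
    assert expanded_kwargs["pixel_values"].shape[0] == 4
    assert expanded_kwargs["attention_mask"].shape[0] == 4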

def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
    # must have this key set to at least None
    model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None)

    # update past with whichever cache the model returned
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs.past_key_values
    elif "mems" in outputs:
        model_kwargs["past"] = outputs.mems
    elif "past_buckets_states" in outputs:
        model_kwargs["past"] = outputs.past_buckets_states
    else:
        model_kwargs["past"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # update attention masks
    if not is_encoder_decoder:
        if "attention_mask" in model_kwargs:
            # Append a column of ones for the token generated at this step.
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )
        if "image_attention_mask" in model_kwargs:
            # Only the mask for the last text position is needed for the next decoding step.
            image_attention_mask = model_kwargs["image_attention_mask"]
            last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
            model_kwargs["image_attention_mask"] = last_mask

    return model_kwargs
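
# A minimal sketch of one decoding-step update. `_DummyOutputs` is a hypothetical
# stand-in for a transformers ModelOutput (it supports both `"key" in outputs` and
# `outputs.key`); the shapes are illustrative assumptions, not values from this file.
class _DummyOutputs(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc


def _demo_update_model_kwargs_for_generation():
    outputs = _DummyOutputs(past_key_values=("dummy_cache",))
    model_kwargs = {
        "attention_mask": torch.ones(1, 4, dtype=torch.long),
        "image_attention_mask": torch.ones(1, 4, 1, dtype=torch.long),
    }
    model_kwargs = update_model_kwargs_for_generation(outputs, model_kwargs)
    assert model_kwargs["past"] == ("dummy_cache",)
    assert model_kwargs["attention_mask"].shape == (1, 5)  # one column appended
    assert model_kwargs["image_attention_mask"].shape == (1, 1, 1)  # last step only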

def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
    token_type_ids = kwargs.get("token_type_ids", None)
    # only keep the last token of input_ids if a past cache is defined
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past:
            position_ids = position_ids[:, -1].unsqueeze(-1)

    pixel_values = kwargs.get("pixel_values", None)
    image_attention_mask = kwargs.get("image_attention_mask", None)
    if pixel_values is None or image_attention_mask is None:
        raise ValueError("`pixel_values` and `image_attention_mask` cannot be None")

    return {
        "input_ids": input_ids,
        "past_key_values": past,
        "use_cache": kwargs.get("use_cache"),
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "pixel_values": pixel_values,
        "image_attention_mask": image_attention_mask,
    }
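
# A minimal sketch of prepare_inputs_for_generation with a cache present: only the
# last token is kept and position_ids are derived from the attention mask. The
# cache object and tensor shapes are illustrative assumptions, not from this file.
def _demo_prepare_inputs_for_generation():
    model_inputs = prepare_inputs_for_generation(
        torch.tensor([[7, 8, 9]]),
        past=("dummy_cache",),
        attention_mask=torch.tensor([[0, 1, 1]]),  # first position is padding
        pixel_values=torch.randn(1, 1, 3, 224, 224),
        image_attention_mask=torch.ones(1, 1, 1, dtype=torch.long),
    )
    assert model_inputs["input_ids"].tolist() == [[9]]  # only the last token with a cache
    assert model_inputs["position_ids"].tolist() == [[1]]  # cumsum of the mask minus one
    assert model_inputs["past_key_values"] == ("dummy_cache",)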