Commit
·
4b66519
1
Parent(s):
6170b43
add assertions and docs
Browse files
- tokenizer.py +11 -2
tokenizer.py
CHANGED
|
@@ -5,9 +5,15 @@ import warnings
|
|
| 5 |
|
| 6 |
|
| 7 |
class JinaTokenizer(RobertaTokenizer):
|
| 8 |
-
def __init__(self, *args,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
super().__init__(*args, **kwargs)
|
| 10 |
-
self.task_type_vocab_size = task_type_vocab_size
|
| 11 |
|
| 12 |
def __call__(self, *args, task_type=None, **kwargs):
|
| 13 |
batch_encoding = super().__call__(*args, **kwargs)
|
|
@@ -50,6 +56,9 @@ class JinaTokenizer(RobertaTokenizer):
|
|
| 50 |
|
| 51 |
def apply_task_type(m, x):
|
| 52 |
x = torch.tensor(x)
|
|
|
|
|
|
|
|
|
|
| 53 |
return m * x if len(x.shape) == 0 else m * x[:, None]
|
| 54 |
|
| 55 |
if isinstance(batch_encoding['input_ids'], torch.Tensor):
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class JinaTokenizer(RobertaTokenizer):
|
| 8 |
+
def __init__(self, *args, **kwargs):
|
| 9 |
+
"""
|
| 10 |
+
JinaTokenizer extends the RobertaTokenizer class to include task_type_ids in
|
| 11 |
+
the batch encoding.
|
| 12 |
+
The task_type_ids are used to pass instruction information to the model.
|
| 13 |
+
A task_type should either be an integer or a sequence of integers with the same
|
| 14 |
+
length as the batch size.
|
| 15 |
+
"""
|
| 16 |
super().__init__(*args, **kwargs)
|
|
|
|
| 17 |
|
| 18 |
def __call__(self, *args, task_type=None, **kwargs):
|
| 19 |
batch_encoding = super().__call__(*args, **kwargs)
|
|
|
|
| 56 |
|
| 57 |
def apply_task_type(m, x):
|
| 58 |
x = torch.tensor(x)
|
| 59 |
+
assert (
|
| 60 |
+
len(x.shape) == 0 or x.shape[0] == m.shape[0]
|
| 61 |
+
), 'The shape of task_type does not match the size of the batch.'
|
| 62 |
return m * x if len(x.shape) == 0 else m * x[:, None]
|
| 63 |
|
| 64 |
if isinstance(batch_encoding['input_ids'], torch.Tensor):
|