jordan0811 committed

Commit 39a7193 · verified · 1 Parent(s): 634c27d

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ fig_accuracy_latency.png filter=lfs diff=lfs merge=lfs -text
+ zebra.jpg filter=lfs diff=lfs merge=lfs -text
bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+ size 1356917
fig_accuracy_latency.png ADDED

Git LFS Details

  • SHA256: 3518a573b474cb0dee8b08ac87925251cf446eb3744a2beb1c807c5ecb5ef840
  • Pointer size: 131 Bytes
  • Size of remote file: 437 kB
run_axmodel.py ADDED
@@ -0,0 +1,101 @@
+ import numpy as np
+ from PIL import Image
+ import axengine as ort
+ import torch
+ from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize, CenterCrop
+ from tokenizer import SimpleTokenizer
+ import argparse
+
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+ def image_transform_v2():
+     resolution = 256
+     resize_size = resolution
+     centercrop_size = resolution
+     mean = OPENAI_DATASET_MEAN
+     std = OPENAI_DATASET_STD
+     aug_list = [
+         Resize(
+             resize_size,
+             interpolation=InterpolationMode.BICUBIC,
+         ),
+         CenterCrop(centercrop_size),
+         ToTensor(),
+         Normalize(mean=mean, std=std)
+     ]
+     preprocess = Compose(aug_list)
+     return preprocess
+
+
+ def image_transform_v1():
+     resolution = 256
+     resize_size = resolution
+     centercrop_size = resolution
+     aug_list = [
+         Resize(
+             resize_size,
+             interpolation=InterpolationMode.BILINEAR,
+         ),
+         CenterCrop(centercrop_size),
+         ToTensor(),
+     ]
+     preprocess = Compose(aug_list)
+     return preprocess
+
+
+ def softmax(x, axis=-1):
+     """
+     Apply the softmax function to a numpy array along the specified axis.
+
+     Args:
+         x: numpy array, the input data
+         axis: axis along which to compute softmax, defaults to the last axis (-1)
+
+     Returns:
+         A numpy array of the same shape as the input, with softmax applied.
+     """
+     # Subtract the maximum to avoid numerical overflow (numerical stabilization)
+     e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
+     # Divide each exponential by the sum over the chosen axis
+     return e_x / np.sum(e_x, axis=axis, keepdims=True)
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("-ie", "--image_encoder_path", type=str, default="./mobileclip2_s4_image_encoder.axmodel",
+                         help="image encoder axmodel path")
+     parser.add_argument("-te", "--text_encoder_path", type=str, default="./mobileclip2_s4_text_encoder.axmodel",
+                         help="text encoder axmodel path")
+     parser.add_argument("-i", "--image", type=str, default="./zebra.jpg",
+                         help="input image path")
+     parser.add_argument("-t", "--class_text", type=str, nargs='+', default=["a zebra", "a dog", "two zebras"],
+                         help='List of captions, e.g.: "a zebra" "a dog" "two zebras"')
+     args = parser.parse_args()
+
+     image_encoder_path = args.image_encoder_path
+     text_encoder_path = args.text_encoder_path
+     # NOTICE: use the v1 preprocessing; the v2 preprocessing shows a relatively large quantization error in pulsar2
+     preprocess = image_transform_v1()
+     tokenizer = SimpleTokenizer(context_length=77)
+
+     image = preprocess(Image.open(args.image).convert('RGB')).unsqueeze(0)
+     text = tokenizer(args.class_text)
+     text = text.to(torch.int32)
+
+     onnx_image_encoder = ort.InferenceSession(image_encoder_path)
+     onnx_text_encoder = ort.InferenceSession(text_encoder_path)
+
+     image_features = onnx_image_encoder.run(["unnorm_image_features"], {"image": np.array(image)})[0]
+     # text_features = []
+     # for i in range(text.shape[0]):
+     #     text_feature = onnx_text_encoder.run(["unnorm_text_features"], {"text": np.array([text[i]])})[0]
+     #     text_features.append(text_feature)
+     # text_features = np.array([t[0] for t in text_features])
+     text_features = onnx_text_encoder.run(["unnorm_text_features"], {"text": text.numpy()})[0]
+     image_features /= np.linalg.norm(image_features, ord=2, axis=-1, keepdims=True)
+     text_features /= np.linalg.norm(text_features, ord=2, axis=-1, keepdims=True)
+
+     text_probs = softmax(100.0 * image_features @ text_features.T)
+
+     print("Label probs:", text_probs)
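With the defaults above, the demo runs as `python run_axmodel.py`, or with explicit inputs, e.g. `python run_axmodel.py -i ./zebra.jpg -t "a zebra" "a dog" "two zebras"`. As a usage illustration only (not part of this commit), the sketch below wraps the same flow into a reusable helper; it imports the functions defined in run_axmodel.py and tokenizer.py from this upload, while the helper name and its packaging are assumptions about how a caller might reuse the script.

# classify_sketch.py - illustrative wrapper around the flow in run_axmodel.py (hypothetical, not in this upload)
import numpy as np
import axengine as ort
from PIL import Image

from tokenizer import SimpleTokenizer
from run_axmodel import image_transform_v1, softmax  # importing does not run the demo (it is guarded by __main__)

_preprocess = image_transform_v1()
_tokenizer = SimpleTokenizer(context_length=77)


def classify(image_path, captions,
             image_encoder_path="./mobileclip2_s4_image_encoder.axmodel",
             text_encoder_path="./mobileclip2_s4_text_encoder.axmodel"):
    """Return softmax probabilities of `captions` matching the image at `image_path`."""
    image_encoder = ort.InferenceSession(image_encoder_path)
    text_encoder = ort.InferenceSession(text_encoder_path)

    # Same preprocessing and tokenization as run_axmodel.py
    image = _preprocess(Image.open(image_path).convert("RGB")).unsqueeze(0).numpy()
    text = _tokenizer(captions).numpy().astype(np.int32)

    image_features = image_encoder.run(["unnorm_image_features"], {"image": image})[0]
    text_features = text_encoder.run(["unnorm_text_features"], {"text": text})[0]

    # L2-normalize, then turn scaled cosine similarities into probabilities
    image_features /= np.linalg.norm(image_features, axis=-1, keepdims=True)
    text_features /= np.linalg.norm(text_features, axis=-1, keepdims=True)
    return softmax(100.0 * image_features @ text_features.T)


if __name__ == "__main__":
    print(classify("./zebra.jpg", ["a zebra", "a dog", "two zebras"]))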
tokenizer.py ADDED
@@ -0,0 +1,621 @@
+ """ CLIP tokenizer
+
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+ """
+ import gzip
+ import html
+ import os
+ import random
+ import string
+ from functools import lru_cache, partial
+ from typing import Callable, List, Optional, Union, Dict
+ import warnings
+
+ import ftfy
+ import numpy as np
+ import regex as re
+ import torch
+
+ # https://stackoverflow.com/q/62691279
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ _nltk_init = False
+
+ DEFAULT_CONTEXT_LENGTH = 77  # default context length for OpenAI CLIP
+
+
+ @lru_cache()
+ def default_bpe():
+     return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+ @lru_cache()
+ def bytes_to_unicode():
+     """
+     Returns list of utf-8 byte and a corresponding list of unicode strings.
+     The reversible bpe codes work on unicode strings.
+     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+     This is a significant percentage of your normal, say, 32K bpe vocab.
+     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+     And avoids mapping to whitespace/control characters the bpe code barfs on.
+     """
+     bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+     cs = bs[:]
+     n = 0
+     for b in range(2**8):
+         if b not in bs:
+             bs.append(b)
+             cs.append(2**8+n)
+             n += 1
+     cs = [chr(n) for n in cs]
+     return dict(zip(bs, cs))
+
+
+ def get_pairs(word):
+     """Return set of symbol pairs in a word.
+     Word is represented as tuple of symbols (symbols being variable-length strings).
+     """
+     pairs = set()
+     prev_char = word[0]
+     for char in word[1:]:
+         pairs.add((prev_char, char))
+         prev_char = char
+     return pairs
+
+
+ def basic_clean(text):
+     text = ftfy.fix_text(text)
+     text = html.unescape(html.unescape(text))
+     return text.strip()
+
+
+ def whitespace_clean(text):
+     text = " ".join(text.split())
+     text = text.strip()
+     return text
+
+
+ def _clean_canonicalize(x):
+     # basic, remove whitespace, remove punctuation, lower case
+     return canonicalize_text(basic_clean(x))
+
+
+ def _clean_lower(x):
+     # basic, remove whitespace, lower case
+     return whitespace_clean(basic_clean(x)).lower()
+
+
+ def _clean_whitespace(x):
+     # basic, remove whitespace
+     return whitespace_clean(basic_clean(x))
+
+
+ def get_clean_fn(type: str):
+     if type == 'canonicalize':
+         return _clean_canonicalize
+     elif type == 'lower':
+         return _clean_lower
+     elif type == 'whitespace':
+         return _clean_whitespace
+     else:
+         assert False, f"Invalid clean function ({type})."
+
+
+ def canonicalize_text(
+         text,
+         *,
+         keep_punctuation_exact_string=None,
+         trans_punctuation: dict = str.maketrans("", "", string.punctuation),
+ ):
+     """Returns canonicalized `text` (lowercase and punctuation removed).
+
+     From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
+
+     Args:
+         text: string to be canonicalized.
+         keep_punctuation_exact_string: If provided, then this exact string kept.
+             For example providing '{}' will keep any occurrences of '{}' (but will
+             still remove '{' and '}' that appear separately).
+     """
+     text = text.replace("_", " ")
+     if keep_punctuation_exact_string:
+         text = keep_punctuation_exact_string.join(
+             part.translate(trans_punctuation)
+             for part in text.split(keep_punctuation_exact_string)
+         )
+     else:
+         text = text.translate(trans_punctuation)
+     text = text.lower()
+     text = " ".join(text.split())
+     return text.strip()
+
+
+ class SimpleTokenizer(object):
+     def __init__(
+             self,
+             bpe_path: str = default_bpe(),
+             additional_special_tokens: Optional[List[str]] = None,
+             context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
+             clean: str = 'lower',
+             reduction_mask: str = ''
+     ):
+         self.byte_encoder = bytes_to_unicode()
+         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+         merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+         merges = merges[1:49152-256-2+1]
+         merges = [tuple(merge.split()) for merge in merges]
+         vocab = list(bytes_to_unicode().values())
+         vocab = vocab + [v+'</w>' for v in vocab]
+         for merge in merges:
+             vocab.append(''.join(merge))
+         special_tokens = ['<start_of_text>', '<end_of_text>']
+         if additional_special_tokens:
+             special_tokens += additional_special_tokens
+         vocab.extend(special_tokens)
+         self.encoder = dict(zip(vocab, range(len(vocab))))
+         self.decoder = {v: k for k, v in self.encoder.items()}
+         self.bpe_ranks = dict(zip(merges, range(len(merges))))
+         self.cache = {t:t for t in special_tokens}
+         special = "|".join(special_tokens)
+         self.pat = re.compile(
+             special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+             re.IGNORECASE,
+         )
+         self.vocab_size = len(self.encoder)
+         self.all_special_ids = [self.encoder[t] for t in special_tokens]
+         self.sot_token_id = self.all_special_ids[0]
+         self.eot_token_id = self.all_special_ids[1]
+         self.context_length = context_length
+         self.clean_fn = get_clean_fn(clean)
+         self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None
+
+     def bpe(self, token):
+         if token in self.cache:
+             return self.cache[token]
+         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+         pairs = get_pairs(word)
+
+         if not pairs:
+             return token+'</w>'
+
+         while True:
+             bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+             if bigram not in self.bpe_ranks:
+                 break
+             first, second = bigram
+             new_word = []
+             i = 0
+             while i < len(word):
+                 try:
+                     j = word.index(first, i)
+                     new_word.extend(word[i:j])
+                     i = j
+                 except Exception:
+                     new_word.extend(word[i:])
+                     break
+
+                 if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                     new_word.append(first+second)
+                     i += 2
+                 else:
+                     new_word.append(word[i])
+                     i += 1
+             new_word = tuple(new_word)
+             word = new_word
+             if len(word) == 1:
+                 break
+             else:
+                 pairs = get_pairs(word)
+         word = ' '.join(word)
+         self.cache[token] = word
+         return word
+
+     def encode(self, text):
+         bpe_tokens = []
+         text = self.clean_fn(text)
+         for token in re.findall(self.pat, text):
+             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+         return bpe_tokens
+
+     def decode(self, tokens):
+         text = ''.join([self.decoder[token] for token in tokens])
+         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+         return text
+
+     def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor:
+         """ Returns the tokenized representation of given input string(s)
+
+         Parameters
+         ----------
+         texts : Union[str, List[str]]
+             An input string or a list of input strings to tokenize
+         context_length : int
+             The context length to use; all CLIP models use 77 as the context length
+
+         Returns
+         -------
+         A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+         """
+         if isinstance(texts, str):
+             texts = [texts]
+
+         context_length = context_length or self.context_length
+         assert context_length, 'Please set a valid context length'
+
+         if self.reduction_fn is not None:
+             # use reduction strategy for tokenize if set, otherwise default to truncation below
+             return self.reduction_fn(
+                 texts,
+                 context_length=context_length,
+                 sot_token_id=self.sot_token_id,
+                 eot_token_id=self.eot_token_id,
+                 encode_fn=self.encode,
+             )
+
+         all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts]
+         result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+         for i, tokens in enumerate(all_tokens):
+             if len(tokens) > context_length:
+                 tokens = tokens[:context_length]  # Truncate
+                 tokens[-1] = self.eot_token_id
+             result[i, :len(tokens)] = torch.tensor(tokens)
+
+         return result
+
+
+ _tokenizer = SimpleTokenizer()
+
+
+ def decode(output_ids: torch.Tensor):
+     output_ids = output_ids.cpu().numpy()
+     return _tokenizer.decode(output_ids)
+
+
+ def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor:
+     return _tokenizer(texts, context_length=context_length)
+
+
+ def random_mask_tokenize(
+         texts: Union[str, List[str]],
+         context_length: int,
+         sot_token_id: int,
+         eot_token_id: int,
+         encode_fn: Callable,
+         shuffle: bool = False,
+ ):
+     all_tokens = [encode_fn(text) for text in texts]
+     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+     for i, tokens in enumerate(all_tokens):
+         tokens = torch.tensor(tokens)
+         num_tokens = len(tokens)
+         if num_tokens > context_length - 2:  # 2 for sot and eot token
+             num_keep = context_length - 2
+             indices = torch.randperm(len(tokens))
+             indices = indices[:num_keep]
+             if not shuffle:
+                 indices = indices.msort()
+             tokens = tokens[indices]
+             num_tokens = num_keep
+         result[i, 0] = sot_token_id
+         result[i, 1:num_tokens + 1] = tokens
+         result[i, num_tokens + 1] = eot_token_id
+
+     return result
+
+
+ def simple_mask_tokenize(
+         texts: Union[str, List[str]],
+         context_length: int,
+         sot_token_id: int,
+         eot_token_id: int,
+         encode_fn: Callable,
+ ):
+     all_tokens = [encode_fn(text) for text in texts]
+     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+     for i, tokens in enumerate(all_tokens):
+         num_tokens = len(tokens)
+         if num_tokens > context_length - 2:  # 2 for sot and eot token
+             num_keep = context_length - 2
+             start_index = random.randint(0, num_tokens - num_keep)  # high is incl
+             tokens = tokens[start_index: start_index + num_keep]
+         tokens = [sot_token_id] + tokens + [eot_token_id]
+         result[i, :len(tokens)] = torch.tensor(tokens)
+
+     return result
+
+
+ def syntax_mask_tokenize(
+         texts: Union[str, List[str]],
+         context_length: int,
+         sot_token_id: int,
+         eot_token_id: int,
+         encode_fn: Callable,
+ ) -> torch.LongTensor:
+     """ Returns the tokenized representation of given input string(s).
+     Apply syntax masking before tokenize.
+     """
+     import nltk
+     global _nltk_init
+     if not _nltk_init:
+         # run them for the first time
+         nltk.download('punkt')
+         nltk.download('averaged_perceptron_tagger')
+         _nltk_init = True
+
+     def get_order(x):
+         if x.startswith('NN'):
+             return 1
+         elif x.startswith('JJ'):
+             return 2
+         elif x.startswith('VB'):
+             return 3
+         else:
+             return 4
+
+     # syntax masking
+     new_texts = []
+     for text in texts:
+         list_tokens = nltk.tokenize.word_tokenize(text)
+         pos_tags = nltk.pos_tag(list_tokens)
+         # sample the words by get_order method
+         order_list = [get_order(tag) for _, tag in pos_tags]
+         sorted_ids = np.argsort(np.array(order_list))
+         sampled_ids = sorted(sorted_ids[:context_length - 2])  # need 2 slots for sot and eot tokens
+         sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0)  # sample the tokens
+
+         new_text = ''
+         for token in sampled_tokens:
+             new_text = new_text + str(token) + ' '
+         new_text = new_text.strip()
+         new_texts.append(new_text)
+     texts = new_texts
+
+     all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts]
+     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+     for i, tokens in enumerate(all_tokens):
+         # still need first truncate because some words produces two tokens
+         if len(tokens) > context_length:
+             tokens = tokens[:context_length]  # Truncate
+             tokens[-1] = eot_token_id
+         result[i, :len(tokens)] = torch.tensor(tokens)
+
+     return result
+
+
+ def get_reduction_mask_fn(type: str):
+     """ Choose strategy for dropping (masking) tokens to achieve target context length"""
+     assert type in ('simple', 'random', 'shuffle', 'syntax')
+     if type == 'simple':
+         return simple_mask_tokenize  # randomly select block [start:end]
+     elif type == 'random':
+         return random_mask_tokenize  # randomly drop tokens (keep order)
+     elif type == 'shuffle':
+         return partial(random_mask_tokenize, shuffle=True)  # randomly drop tokens (shuffle order)
+     elif type == 'syntax':
+         return syntax_mask_tokenize  # randomly drop prioritized by syntax
+     else:
+         assert False, F'Unknown type {type}.'
+
+
+ class HFTokenizer:
+     """HuggingFace tokenizer wrapper with support for custom tokenization modes"""
+
+     def __init__(
+             self,
+             tokenizer_name: str,
+             context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
+             clean: str = 'whitespace',
+             strip_sep_token: bool = False,
+             language: Optional[str] = None,
+             cache_dir: Optional[str] = None,
+             tokenizer_mode: Optional[str] = None,  # None, 'clips'
+             **kwargs
+     ):
+         self.tokenizer_mode = tokenizer_mode or ''
+         self.context_length = context_length
+         self.clean_fn = get_clean_fn(clean)
+         self.strip_sep_token = strip_sep_token
+
+         # NOTE: Left as example of loading custom tokenizer from file for experimentation
+         # if self.tokenizer_mode == 'bert_clips':
+         #     self.special_tokens = {
+         #         "bos_token": 1,
+         #         "eos_token": 2,
+         #         "cls_token": 101,
+         #         "pad_token": 0
+         #     }
+         #
+         #     # For BERT CLIPS mode with vocab file
+         #     from tokenizers import BertWordPieceTokenizer
+         #     if tokenizer_name.startswith('hf-hub:'):
+         #         from huggingface_hub import hf_hub_download
+         #         # Format: hf-hub:repo_id/filename
+         #         repo_url = tokenizer_name[7:]
+         #         parts = repo_url.split('/')
+         #         filename = parts[-1]
+         #         repo_id = '/'.join(parts[:-1])
+         #         vocab_file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
+         #         self.tokenizer = BertWordPieceTokenizer(lowercase=True)
+         #         self.tokenizer = self.tokenizer.from_file(vocab_file)
+         #     else:
+         #         # Assume tokenizer_name is a local path to a vocab file
+         #         self.tokenizer = BertWordPieceTokenizer(lowercase=True)
+         #         self.tokenizer = self.tokenizer.from_file(tokenizer_name)
+
+         # Standard HuggingFace tokenizer initialization
+         from transformers import AutoTokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             tokenizer_name,
+             cache_dir=cache_dir,
+             **kwargs
+         )
+
+         # Set language function if available
+         set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
+         if callable(set_lang_fn):
+             self.set_lang_fn = set_lang_fn
+         if language is not None:
+             self.set_language(language)
+
+     def save_pretrained(self, dest):
+         self.tokenizer.save_pretrained(dest)
+
+     def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
+         # same cleaning as for default tokenizer, except lowercasing
+         # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
+         if isinstance(texts, str):
+             texts = [texts]
+
+         context_length = context_length or self.context_length
+         assert context_length, 'Please set a valid context length in class init or call.'
+
+         texts = [self.clean_fn(text) for text in texts]
+
+         # Handle different tokenization modes
+         if self.tokenizer_mode == 'clips':
+             return self._clips_tokenize(texts, context_length)
+         else:
+             # Standard tokenization
+             input_ids = self.tokenizer.batch_encode_plus(
+                 texts,
+                 return_tensors='pt',
+                 max_length=context_length,
+                 padding='max_length',
+                 truncation=True,
+             ).input_ids
+
+             if self.strip_sep_token:
+                 input_ids = torch.where(
+                     input_ids == self.tokenizer.sep_token_id,
+                     torch.zeros_like(input_ids),
+                     input_ids,
+                 )
+
+             return input_ids
+
+     def set_language(self, src_lang):
+         if hasattr(self, 'set_lang_fn'):
+             self.set_lang_fn(src_lang)
+         else:
+             warnings.warn('Cannot set language for the tokenizer.')
+
+     def _clips_tokenize(self, texts: List[str], context_length: int) -> torch.Tensor:
+         """Use standard HF tokenizer but apply custom post-processing"""
+         # Use standard tokenizer without special tokens - we'll add our own
+         encoded_outputs = self.tokenizer.batch_encode_plus(
+             texts,
+             add_special_tokens=False,
+             padding=False,
+             truncation=False,
+             return_tensors=None
+         )
+
+         encoded = []
+         for tokens in encoded_outputs["input_ids"]:
+             tokens = tokens[:context_length - 3]  # Leave room for special tokens
+             tokens = [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
+             encoded.append(tokens)
+
+         # Create result tensor and handle padding + class token
+         result = torch.zeros(len(encoded), context_length, dtype=torch.long)
+         for i, tokens in enumerate(encoded):
+             padded_tokens = self._pad_and_add_class_token(
+                 tokens,
+                 max_length=context_length,
+                 pad_token_id=self.tokenizer.pad_token_id,
+                 cls_token_id=self.tokenizer.cls_token_id,
+             )
+             result[i, :len(padded_tokens)] = torch.tensor(padded_tokens)
+
+         return result
+
+     def _pad_and_add_class_token(
+             self,
+             tokens: List[int],
+             max_length: int,
+             pad_token_id: int = 0,
+             cls_token_id: int = 101,
+     ) -> List[int]:
+         """ Add padding with class token at the end """
+         if len(tokens) > max_length - 1:
+             tokens = tokens[:max_length - 1]
+
+         # Add padding to reach max_length-1
+         if len(tokens) < max_length - 1:
+             tokens = tokens + [pad_token_id] * (max_length - 1 - len(tokens))
+
+         # Add class token at the end
+         tokens = tokens + [cls_token_id]
+         return tokens
+
+
+ class SigLipTokenizer:
+     """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs
+
+     NOTE: this is not needed in normal library use, but is used to import new sentencepiece tokenizers
+     into OpenCLIP. Leaving code here in case future models use new tokenizers.
+     """
+     VOCAB_FILES = {
+         # english, vocab_size=32_000
+         "c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model",
+         # used in multilingual models (mT5, PaLI), vocab_size=250_000
+         "mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
+         # used in SigLIP2 models, vocab_size=256000
+         "gemma": "http://storage.googleapis.com/big_vision/gemma_tokenizer.model",
+     }
+
+     def __init__(
+             self,
+             tokenizer_name: str,
+             context_length: Optional[int] = 64,
+     ):
+         if 'gemma' in tokenizer_name:
+             from transformers import GemmaTokenizerFast
+             tokenizer_cls = partial(
+                 GemmaTokenizerFast, padding_side='right', add_bos_token=False, add_eos_token=True)
+         else:
+             from transformers import T5TokenizerFast
+             tokenizer_cls = partial(T5TokenizerFast, extra_ids=0)
+
+         if tokenizer_name in self.VOCAB_FILES:
+             # FIXME temporary hack?
+             import tempfile
+             import fsspec
+             vocab_file = self.VOCAB_FILES[tokenizer_name]
+             with tempfile.NamedTemporaryFile('wb') as dst:
+                 with fsspec.open(vocab_file, 'rb') as src:
+                     dst.write(src.read())
+                 self.tokenizer = tokenizer_cls(dst.name, legacy=False)
+         else:
+             self.tokenizer = tokenizer_cls(tokenizer_name, legacy=False)
+
+         self.tokenizer.pad_token_id = 0 if 'gemma' in tokenizer_name else 1
+         self.tokenizer.eos_token_id = 1
+         self.context_length = context_length
+
+     def save_pretrained(self, dest):
+         self.tokenizer.save_pretrained(dest)
+
+     def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
+         # same cleaning as for default tokenizer, except lowercasing
+         # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
+         if isinstance(texts, str):
+             texts = [texts]
+
+         context_length = context_length or self.context_length
+         assert context_length, 'Please set a valid context length in class init or call.'
+
+         texts = [canonicalize_text(basic_clean(text)) for text in texts]
+         output = self.tokenizer(
+             texts,
+             return_tensors='pt',
+             max_length=context_length,
+             padding='max_length',
+             truncation=True,
+         )
+         return output.input_ids
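For reference, here is a minimal usage sketch of the `SimpleTokenizer` defined above (illustration only, not a file in this commit); it assumes `bpe_simple_vocab_16e6.txt.gz` sits next to `tokenizer.py`, as it does in this upload.

# Illustrative use of SimpleTokenizer (not part of this commit)
from tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer(context_length=77)

tokens = tokenizer(["a zebra", "a dog", "two zebras"])
print(tokens.shape)    # torch.Size([3, 77]); each row is <start_of_text>, BPE ids, <end_of_text>, zero padding
print(tokens[0, :6])   # first few ids of the "a zebra" row

# Round-trip the non-padding ids of the first caption back to text
ids = [int(t) for t in tokens[0] if int(t) != 0]
print(tokenizer.decode(ids))   # expected to read roughly "<start_of_text>a zebra <end_of_text>"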
zebra.jpg ADDED

Git LFS Details

  • SHA256: c97b8b6f195541932de28e6a79bacd569d65d29e0a32f04274885acc0cab505c
  • Pointer size: 131 Bytes
  • Size of remote file: 246 kB