# -*- coding: utf-8 -*-
# file: omnigenbench_wrapper.py
# time: 00:57 27/04/2024
# author: YANG, HENG (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.

import warnings

from omnigenbench import OmniTokenizer


class Tokenizer(OmniTokenizer):
    """Single-nucleotide tokenizer wrapper around a HuggingFace base tokenizer."""

    def __init__(self, base_tokenizer=None, u2t=True, add_whitespace=False, **kwargs):
        super(Tokenizer, self).__init__(
            base_tokenizer, u2t=u2t, add_whitespace=add_whitespace, **kwargs
        )
        self.metadata["tokenizer_name"] = self.__class__.__name__

    def __call__(self, sequence, **kwargs):
        # Normalize the input to a list of sequences so that single strings
        # and batches are handled uniformly.
        if isinstance(sequence, str):
            sequences = [sequence]
        else:
            sequences = list(sequence)

        # Optional U<->T conversion (RNA vs. DNA alphabet) and whitespace
        # insertion between nucleotides; the flags are set by OmniTokenizer.
        if self.u2t:
            sequences = [seq.replace("U", "T").upper() for seq in sequences]
        if self.t2u:
            sequences = [seq.replace("T", "U").upper() for seq in sequences]
        if self.add_whitespace:
            sequences = [" ".join(list(seq)) for seq in sequences]

        # Truncate each sequence to max_length - 2 tokens to leave room for
        # the BOS/EOS (or CLS/SEP) special tokens added below.
        max_length = kwargs.get("max_length", self.max_length)
        sequence_tokens = [
            tokens[: max_length - 2] for tokens in self.tokenize(sequences)
        ]

        tokenized_inputs = {
            "input_ids": [],
            "attention_mask": [],
        }
        # Fall back to CLS/SEP when the base tokenizer defines no BOS/EOS ids.
        bos_id = (
            self.base_tokenizer.bos_token_id
            if self.base_tokenizer.bos_token_id is not None
            else self.base_tokenizer.cls_token_id
        )
        eos_id = (
            self.base_tokenizer.eos_token_id
            if self.base_tokenizer.eos_token_id is not None
            else self.base_tokenizer.sep_token_id
        )

        for tokens in sequence_tokens:
            tokenized_inputs["input_ids"].append(
                [bos_id] + self.base_tokenizer.convert_tokens_to_ids(tokens) + [eos_id]
            )
            tokenized_inputs["attention_mask"].append(
                [1] * len(tokenized_inputs["input_ids"][-1])
            )

        # Warn when more than 10% of a sequence maps to the unknown token,
        # which usually indicates a vocabulary mismatch.
        for i, ids in enumerate(tokenized_inputs["input_ids"]):
            unk_ratio = ids.count(self.base_tokenizer.unk_token_id) / len(ids)
            if unk_ratio > 0.1:
                warnings.warn(
                    f"{unk_ratio:.1%} of the tokens in the {i}-th sequence are unknown, "
                    f"please check the tokenization process."
                )

        batch_max_length = max(len(ids) for ids in tokenized_inputs["input_ids"])
        tokenized_inputs = self.base_tokenizer.pad(
            tokenized_inputs,
            padding=kwargs.get("padding", "max_length"),
            max_length=min(batch_max_length, kwargs.get("max_length", 512)),
            return_attention_mask=kwargs.get("return_attention_mask", True),
            return_tensors="pt",
        )
        return tokenized_inputs

    def tokenize(self, sequence, **kwargs):
        # Character-level tokenization: each nucleotide becomes one token.
        if isinstance(sequence, str):
            sequences = [sequence]
        else:
            sequences = sequence
        return [list(seq) for seq in sequences]

    def encode(self, sequence, **kwargs):
        return self.base_tokenizer.encode(sequence, **kwargs)

    def decode(self, sequence, **kwargs):
        return self.base_tokenizer.decode(sequence, **kwargs)

    def encode_plus(self, sequence, **kwargs):
        return self.base_tokenizer.encode_plus(sequence, **kwargs)
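
# --------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original wrapper).
# It assumes the base tokenizer is any HuggingFace tokenizer loaded with
# transformers.AutoTokenizer; the checkpoint name below is a placeholder, and
# max_length is assumed to be forwarded to OmniTokenizer via **kwargs.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # Replace with a genomic foundation-model checkpoint of your choice.
    base = AutoTokenizer.from_pretrained("your/genomic-model-checkpoint")
    tokenizer = Tokenizer(base_tokenizer=base, u2t=True, max_length=512)

    # Works for a single RNA/DNA string or a batch of strings; "U" bases are
    # mapped to "T" before character-level tokenization.
    batch = tokenizer(["AUGGCUACG", "ACGTACGT"], padding="max_length", max_length=16)
    print(batch["input_ids"].shape)       # (2, padded_length)
    print(batch["attention_mask"].shape)  # matches input_ids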