Text Retrieval
Transformers
Safetensors
sentence-transformers
English
kpr-bert
feature-extraction
custom_code
Instructions for using knowledgeable-ai/kpr-bge-base-en with libraries, inference providers, notebooks, and local apps. Follow the links below to get started.
- Libraries
- Transformers
How to use knowledgeable-ai/kpr-bge-base-en with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("knowledgeable-ai/kpr-bge-base-en", trust_remote_code=True, dtype="auto")
```
- sentence-transformers
How to use knowledgeable-ai/kpr-bge-base-en with sentence-transformers:
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("knowledgeable-ai/kpr-bge-base-en", trust_remote_code=True)

sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium.",
]
embeddings = model.encode(sentences)

similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)  # [3, 3]
```
- Notebooks
- Google Colab
- Kaggle
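The custom tokenization code below is also part of the model card; it is presumably what `trust_remote_code=True` pulls in. It defines `KPRBertTokenizer`, a `BertTokenizer` subclass that runs a dictionary-based entity linker over the input text and adds `entity_ids`, `entity_position_ids`, and (optionally) precomputed `entity_embeds` alongside the usual BERT inputs.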
```python
from __future__ import annotations

import csv
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple

import numpy as np
import torch
import spacy
from marisa_trie import Trie
from transformers import BatchEncoding, BertTokenizer, PreTrainedTokenizerBase

NONE_ID = "<None>"

@dataclass
class Mention:
    kb_id: str | None
    text: str
    start: int
    end: int
    link_count: int | None
    total_link_count: int | None
    doc_count: int | None

    @property
    def span(self) -> tuple[int, int]:
        return self.start, self.end

    @property
    def link_prob(self) -> float | None:
        if self.doc_count is None or self.total_link_count is None:
            return None
        elif self.doc_count > 0:
            return min(1.0, self.total_link_count / self.doc_count)
        else:
            return 0.0

    @property
    def prior_prob(self) -> float | None:
        if self.link_count is None or self.total_link_count is None:
            return None
        elif self.total_link_count > 0:
            return min(1.0, self.link_count / self.total_link_count)
        else:
            return 0.0

    def __repr__(self):
        return f"<Mention {self.text} -> {self.kb_id}>"

def get_tokenizer(language: str) -> spacy.tokenizer.Tokenizer:
    language_obj = spacy.blank(language)
    return language_obj.tokenizer


class DictionaryEntityLinker:
    def __init__(
        self,
        name_trie: Trie,
        kb_id_trie: Trie,
        data: np.ndarray,
        offsets: np.ndarray,
        max_mention_length: int,
        case_sensitive: bool,
        min_link_prob: float | None,
        min_prior_prob: float | None,
        min_link_count: int | None,
    ):
        self._name_trie = name_trie
        self._kb_id_trie = kb_id_trie
        self._data = data
        self._offsets = offsets
        self._max_mention_length = max_mention_length
        self._case_sensitive = case_sensitive
        self._min_link_prob = min_link_prob
        self._min_prior_prob = min_prior_prob
        self._min_link_count = min_link_count
        self._tokenizer = get_tokenizer("en")

    @staticmethod
    def load(
        data_dir: str,
        min_link_prob: float | None = None,
        min_prior_prob: float | None = None,
        min_link_count: int | None = None,
    ) -> "DictionaryEntityLinker":
        data = np.load(os.path.join(data_dir, "data.npy"))
        offsets = np.load(os.path.join(data_dir, "offsets.npy"))
        name_trie = Trie()
        name_trie.load(os.path.join(data_dir, "name.trie"))
        kb_id_trie = Trie()
        kb_id_trie.load(os.path.join(data_dir, "kb_id.trie"))
        with open(os.path.join(data_dir, "config.json")) as config_file:
            config = json.load(config_file)
        if min_link_prob is None:
            min_link_prob = config.get("min_link_prob", None)
        if min_prior_prob is None:
            min_prior_prob = config.get("min_prior_prob", None)
        if min_link_count is None:
            min_link_count = config.get("min_link_count", None)
        return DictionaryEntityLinker(
            name_trie=name_trie,
            kb_id_trie=kb_id_trie,
            data=data,
            offsets=offsets,
            max_mention_length=config["max_mention_length"],
            case_sensitive=config["case_sensitive"],
            min_link_prob=min_link_prob,
            min_prior_prob=min_prior_prob,
            min_link_count=min_link_count,
        )

    def detect_mentions(self, text: str) -> list[Mention]:
        tokens = self._tokenizer(text)
        # Valid mention boundaries must coincide with token end offsets
        end_offsets = frozenset(token.idx + len(token) for token in tokens)
        if not self._case_sensitive:
            text = text.lower()

        ret = []
        cur = 0
        for token in tokens:
            start = token.idx
            if cur > start:
                continue

            # Greedy longest match: try the longest dictionary entry first
            for prefix in sorted(
                self._name_trie.prefixes(text[start : start + self._max_mention_length]),
                key=len,
                reverse=True,
            ):
                end = start + len(prefix)
                if end in end_offsets:
                    matched = False
                    mention_idx = self._name_trie[prefix]
                    data_start, data_end = self._offsets[mention_idx : mention_idx + 2]
                    for item in self._data[data_start:data_end]:
                        if item.size == 4:
                            kb_idx, link_count, total_link_count, doc_count = item
                        elif item.size == 1:
                            (kb_idx,) = item
                            link_count, total_link_count, doc_count = None, None, None
                        else:
                            raise ValueError("Unexpected data array format")
                        mention = Mention(
                            kb_id=self._kb_id_trie.restore_key(kb_idx),
                            text=prefix,
                            start=start,
                            end=end,
                            link_count=link_count,
                            total_link_count=total_link_count,
                            doc_count=doc_count,
                        )
                        if item.size == 1 or (
                            mention.link_prob >= self._min_link_prob
                            and mention.prior_prob >= self._min_prior_prob
                            and mention.link_count >= self._min_link_count
                        ):
                            ret.append(mention)
                            matched = True
                    if matched:
                        cur = end
                        break
        return ret

    def detect_mentions_batch(self, texts: list[str]) -> list[list[Mention]]:
        return [self.detect_mentions(text) for text in texts]

    def save(self, data_dir: str) -> None:
        """
        Save the entity linker data to the specified directory.

        Args:
            data_dir: Directory to save the entity linker data
        """
        os.makedirs(data_dir, exist_ok=True)

        # Save numpy arrays
        np.save(os.path.join(data_dir, "data.npy"), self._data)
        np.save(os.path.join(data_dir, "offsets.npy"), self._offsets)

        # Save tries
        self._name_trie.save(os.path.join(data_dir, "name.trie"))
        self._kb_id_trie.save(os.path.join(data_dir, "kb_id.trie"))

        # Save configuration
        with open(os.path.join(data_dir, "config.json"), "w") as config_file:
            json.dump(
                {
                    "max_mention_length": self._max_mention_length,
                    "case_sensitive": self._case_sensitive,
                    "min_link_prob": self._min_link_prob,
                    "min_prior_prob": self._min_prior_prob,
                    "min_link_count": self._min_link_count,
                },
                config_file,
            )


def load_tsv_entity_vocab(file_path: str) -> dict[str, int]:
    vocab = {}
    with open(file_path, "r", encoding="utf-8") as file:
        reader = csv.reader(file, delimiter="\t")
        for row in reader:
            vocab[row[0]] = int(row[1])
    return vocab


def save_tsv_entity_vocab(file_path: str, entity_vocab: dict[str, int]) -> None:
    """
    Save entity vocabulary to a TSV file.

    Args:
        file_path: Path to save the entity vocabulary
        entity_vocab: Entity vocabulary to save
    """
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter="\t")
        for entity_id, idx in entity_vocab.items():
            writer.writerow([entity_id, idx])

class _Entity(NamedTuple):
    entity_id: int
    start: int
    end: int

    def length(self) -> int:
        return self.end - self.start


def preprocess_text(
    text: str,
    mentions: list[Mention] | None,
    title: str | None,
    title_mentions: list[Mention] | None,
    tokenizer: PreTrainedTokenizerBase,
    entity_vocab: dict[str, int],
) -> dict[str, list[int]]:
    tokens = []
    entity_ids = []
    entity_position_ids = []
    if title is not None:
        if title_mentions is None:
            title_mentions = []
        title_tokens, title_entities = _tokenize_text_with_mentions(title, title_mentions, tokenizer, entity_vocab)
        tokens += title_tokens + [tokenizer.sep_token]
        for entity in title_entities:
            entity_ids.append(entity.entity_id)
            entity_position_ids.append(list(range(entity.start, entity.end)))

    if mentions is None:
        mentions = []
    entity_offset = len(tokens)
    text_tokens, text_entities = _tokenize_text_with_mentions(text, mentions, tokenizer, entity_vocab)
    tokens += text_tokens
    for entity in text_entities:
        entity_ids.append(entity.entity_id)
        entity_position_ids.append(list(range(entity.start + entity_offset, entity.end + entity_offset)))

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    return {
        "input_ids": input_ids,
        "entity_ids": entity_ids,
        "entity_position_ids": entity_position_ids,
    }

def _tokenize_text_with_mentions(
    text: str,
    mentions: list[Mention],
    tokenizer: PreTrainedTokenizerBase,
    entity_vocab: dict[str, int],
) -> tuple[list[str], list[_Entity]]:
    """
    Tokenize text while preserving mention boundaries and mapping entities.

    Args:
        text: Input text to tokenize
        mentions: List of detected mentions in the text
        tokenizer: Pre-trained tokenizer to use for tokenization
        entity_vocab: Mapping from entity KB IDs to entity vocabulary indices

    Returns:
        Tuple containing:
            - List of tokens from the tokenized text
            - List of _Entity objects with entity IDs and token positions
    """
    target_mentions = [mention for mention in mentions if mention.kb_id is not None and mention.kb_id in entity_vocab]
    split_char_positions = {mention.start for mention in target_mentions} | {mention.end for mention in target_mentions}

    tokens: list[str] = []
    cur = 0
    char_to_token_mapping = {}
    for char_position in sorted(split_char_positions):
        target_text = text[cur:char_position]
        tokens += tokenizer.tokenize(target_text)
        char_to_token_mapping[char_position] = len(tokens)
        cur = char_position
    tokens += tokenizer.tokenize(text[cur:])

    entities = [
        _Entity(
            entity_vocab[mention.kb_id],
            char_to_token_mapping[mention.start],
            char_to_token_mapping[mention.end],
        )
        for mention in target_mentions
    ]
    return tokens, entities

class KPRBertTokenizer(BertTokenizer):
    vocab_files_names = {
        **BertTokenizer.vocab_files_names,  # Include the parent class files (vocab.txt)
        "entity_linker_data_file": "entity_linker/data.npy",
        "entity_linker_offsets_file": "entity_linker/offsets.npy",
        "entity_linker_name_trie_file": "entity_linker/name.trie",
        "entity_linker_kb_id_trie_file": "entity_linker/kb_id.trie",
        "entity_linker_config_file": "entity_linker/config.json",
        "entity_vocab_file": "entity_vocab.tsv",
        "entity_embeddings_file": "entity_embeddings.npy",
    }
    model_input_names = [
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "entity_ids",
        "entity_position_ids",
    ]

    def __init__(
        self,
        vocab_file,
        entity_linker_data_file: str,
        entity_vocab_file: str,
        entity_embeddings_file: str | None = None,
        *args,
        **kwargs,
    ):
        super().__init__(vocab_file=vocab_file, *args, **kwargs)
        entity_linker_dir = str(Path(entity_linker_data_file).parent)
        self.entity_linker = DictionaryEntityLinker.load(entity_linker_dir)
        self.entity_to_id = load_tsv_entity_vocab(entity_vocab_file)
        self.id_to_entity = {v: k for k, v in self.entity_to_id.items()}
        self.entity_embeddings = None
        if entity_embeddings_file:
            # Use memory-mapped loading for large embeddings
            self.entity_embeddings = np.load(entity_embeddings_file, mmap_mode="r")
            if self.entity_embeddings.shape[0] != len(self.entity_to_id):
                raise ValueError(
                    f"Entity embeddings shape {self.entity_embeddings.shape[0]} does not match "
                    f"the number of entities {len(self.entity_to_id)}. "
                    "Make sure `embeddings.py` and `entity_vocab.tsv` are consistent."
                )

    def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int | list[int]]]:
        mentions = self.entity_linker.detect_mentions(text)
        model_inputs = preprocess_text(
            text=text,
            mentions=mentions,
            title=None,
            title_mentions=None,
            tokenizer=self,
            entity_vocab=self.entity_to_id,
        )
        # Prepare the inputs for the model.
        # This will add special tokens or truncate the input when specified in kwargs.
        # We exclude "return_tensors" from kwargs to avoid issues in passing
        # the data to BatchEncoding outside this method.
        prepared_inputs = self.prepare_for_model(
            model_inputs["input_ids"],
            **{k: v for k, v in kwargs.items() if k != "return_tensors"},
        )
        model_inputs.update(prepared_inputs)

        # Account for special tokens
        if kwargs.get("add_special_tokens", True):
            if prepared_inputs["input_ids"][0] != self.cls_token_id:
                raise ValueError(
                    "We assume that the input IDs start with the [CLS] token when add_special_tokens = True."
                )
            # Shift the entity position IDs by 1 to account for the [CLS] token
            model_inputs["entity_position_ids"] = [
                [pos + 1 for pos in positions] for positions in model_inputs["entity_position_ids"]
            ]

        # If there are no entities in the text, output the padding entity for the model
        if not model_inputs["entity_ids"]:
            model_inputs["entity_ids"] = [0]  # The padding entity id is 0
            model_inputs["entity_position_ids"] = [[0]]

        # Count the number of special tokens at the end of the input
        num_special_tokens_at_end = 0
        input_ids = prepared_inputs["input_ids"]
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()
        for input_id in input_ids[::-1]:
            if int(input_id) not in {
                self.sep_token_id,
                self.pad_token_id,
                self.cls_token_id,
            }:
                break
            num_special_tokens_at_end += 1

        # Remove entities that are not in the truncated input
        max_effective_pos = len(model_inputs["input_ids"]) - num_special_tokens_at_end
        entity_indices_to_keep = list()
        for i, position_ids in enumerate(model_inputs["entity_position_ids"]):
            if len(position_ids) > 0 and max(position_ids) < max_effective_pos:
                entity_indices_to_keep.append(i)
        model_inputs["entity_ids"] = [model_inputs["entity_ids"][i] for i in entity_indices_to_keep]
        model_inputs["entity_position_ids"] = [model_inputs["entity_position_ids"][i] for i in entity_indices_to_keep]

        if self.entity_embeddings is not None:
            model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]].astype(np.float32)
        return model_inputs

    def __call__(self, text: str | list[str], **kwargs) -> BatchEncoding:
        for unsupported_arg in ["text_pair", "text_target", "text_pair_target"]:
            if unsupported_arg in kwargs:
                raise ValueError(
                    f"Argument '{unsupported_arg}' is not supported by {self.__class__.__name__}. "
                    "This tokenizer only supports single text inputs."
                )

        if isinstance(text, str):
            processed_inputs = self._preprocess_text(text, **kwargs)
            return BatchEncoding(
                processed_inputs,
                tensor_type=kwargs.get("return_tensors", None),
                prepend_batch_axis=True,
            )

        processed_inputs_list: list[dict[str, list[int]]] = [self._preprocess_text(t, **kwargs) for t in text]
        collated_inputs = {
            key: [item[key] for item in processed_inputs_list] for key in processed_inputs_list[0].keys()
        }
        if kwargs.get("padding"):
            collated_inputs = self.pad(
                collated_inputs,
                padding=kwargs["padding"],
                max_length=kwargs.get("max_length"),
                pad_to_multiple_of=kwargs.get("pad_to_multiple_of"),
                return_attention_mask=kwargs.get("return_attention_mask"),
                verbose=kwargs.get("verbose", True),
            )
            # Pad entity ids
            max_num_entities = max(len(ids) for ids in collated_inputs["entity_ids"])
            for entity_ids in collated_inputs["entity_ids"]:
                entity_ids += [0] * (max_num_entities - len(entity_ids))
            # Pad entity position ids
            flattened_entity_length = [
                len(ids) for ids_list in collated_inputs["entity_position_ids"] for ids in ids_list
            ]
            max_entity_token_length = max(flattened_entity_length) if flattened_entity_length else 0
            for entity_position_ids_list in collated_inputs["entity_position_ids"]:
                # pad entity_position_ids to max_entity_token_length
                for entity_position_ids in entity_position_ids_list:
                    entity_position_ids += [0] * (max_entity_token_length - len(entity_position_ids))
                # pad to max_num_entities
                entity_position_ids_list += [[0 for _ in range(max_entity_token_length)]] * (
                    max_num_entities - len(entity_position_ids_list)
                )
            # Pad entity embeddings
            if "entity_embeds" in collated_inputs:
                for i in range(len(collated_inputs["entity_embeds"])):
                    collated_inputs["entity_embeds"][i] = np.pad(
                        collated_inputs["entity_embeds"][i],
                        pad_width=(
                            (
                                0,
                                max_num_entities - len(collated_inputs["entity_embeds"][i]),
                            ),
                            (0, 0),
                        ),
                        mode="constant",
                        constant_values=0,
                    )
        return BatchEncoding(collated_inputs, tensor_type=kwargs.get("return_tensors", None))

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
        os.makedirs(save_directory, exist_ok=True)
        saved_files = list(super().save_vocabulary(save_directory, filename_prefix))

        # Save entity linker data
        entity_linker_save_dir = str(
            Path(save_directory) / Path(self.vocab_files_names["entity_linker_data_file"]).parent
        )
        self.entity_linker.save(entity_linker_save_dir)
        for file_name in self.vocab_files_names.values():
            if file_name.startswith("entity_linker/"):
                saved_files.append(file_name)

        # Save entity vocabulary
        entity_vocab_path = str(Path(save_directory) / self.vocab_files_names["entity_vocab_file"])
        save_tsv_entity_vocab(entity_vocab_path, self.entity_to_id)
        saved_files.append(self.vocab_files_names["entity_vocab_file"])

        if self.entity_embeddings is not None:
            entity_embeddings_path = str(Path(save_directory) / self.vocab_files_names["entity_embeddings_file"])
            np.save(entity_embeddings_path, self.entity_embeddings)
            saved_files.append(self.vocab_files_names["entity_embeddings_file"])
        return tuple(saved_files)
```
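For reference, a minimal end-to-end sketch of calling this tokenizer and model through the Auto classes is shown below. This is an untested illustration: it assumes the repository's remote code registers `KPRBertTokenizer` and the KPR model with `AutoTokenizer`/`AutoModel`, and the example passages are placeholders.

```python
# Hedged sketch: batch-encode passages with the entity-aware tokenizer and run the model.
# Assumes the remote code maps AutoTokenizer/AutoModel to the custom classes.
from transformers import AutoModel, AutoTokenizer

repo = "knowledgeable-ai/kpr-bge-base-en"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)

passages = [
    "Tokyo is the capital of Japan.",
    "The Eiffel Tower is located in Paris.",
]
# padding=True triggers the entity_ids / entity_position_ids padding shown above
inputs = tokenizer(passages, padding=True, truncation=True, return_tensors="pt")
outputs = model(**inputs)
```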