| | import os |
| | import spacy |
| | import torch |
| | from torch.utils.data import DataLoader |
| | from transformers import AutoTokenizer |
| | from .utils import get_idxs_from_text |
| | import streamlit as st |
| | from annotated_text import annotated_text |
| | from .nugget_model_utils import CustomRobertaWithPOS |
| | from .event_nugget_predict import get_event_nuggets |
| | from .realis_model_utils import get_entity_for_realis_from_idx, tokenize_and_align_labels_with_pos_ner_realis |
| | from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset |
| |
|
# BIO tag vocabulary for event nuggets. The list order defines the class ids
# of the ``ner_tags`` ClassLabel built in create_dataloader.
event_nugget_list = [
    'B-Phishing',
    'I-Phishing',
    'O',
    'B-DiscoverVulnerability',
    'B-Ransom',
    'I-Ransom',
    'B-Databreach',
    'I-DiscoverVulnerability',
    'B-PatchVulnerability',
    'I-PatchVulnerability',
    'I-Databreach',
]

# Realis classes; a predicted class id is turned into its name by indexing this list.
realis_list = "O Generic Other Actual".split()

os.environ.update({"TOKENIZERS_PARALLELISM": "true"})
| |
|
| |
|
| |
|
def find_dep_depth(token):
    """Return the number of head hops from ``token`` up to its root, capped at 16.

    The walk stops at the token that is its own head. The cap keeps the value
    within the 17 classes (0-16) of the ``depth_spacy`` ClassLabel.
    """
    hops = 0
    while token.head != token:
        token = token.head
        hops += 1
    return hops if hops < 16 else 16
| |
|
| |
|
# spaCy English pipeline; it provides the POS / NER / dependency features
# attached to every word in create_dataloader.
nlp = spacy.load('en_core_web_sm')

# POS tag vocabulary (Token.pos_ values); the order defines the ``pos_spacy`` class ids.
pos_spacy_tag_list = [
    "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
    "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "SPACE", "X",
]
# For every NER label of the loaded pipe: "B-<label>", "I-<label>"; then a final "O".
ner_spacy_tag_list = [
    prefix + label
    for label in nlp.get_pipe('ner').labels
    for prefix in ("B-", "I-")
] + ["O"]
# Dependency labels known to the loaded parser.
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)

# NOTE(review): ``device`` is not used by any visible code in this module.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_checkpoint = "ehsanaghaei/SecureBERT"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
| |
|
| | |
| | |
| | |
| | |
| |
|
def _split_on_sentence_ends(example, content_token_len, max_length):
    """Slice one per-word feature dict into several pieces cut after "." words.

    Helper for ``create_dataloader``. Every list in ``example`` is sliced at
    the same word boundaries. A cut is made after a "." word whose index
    exceeds the current threshold; the threshold starts at ``split_len`` and
    grows by ``split_len`` after each cut. The remainder after the last cut
    becomes the final piece, and is omitted when empty.
    """
    words = example["tokens"]
    no_split = (content_token_len // max_length) + 2
    split_len = (len(words) // no_split) + 1

    def _slice(start, stop=None):
        return {key: values[start:stop] for key, values in example.items()}

    pieces = []
    last_id = 0
    threshold = split_len
    for word_idx, word in enumerate(words):
        if word == "." and word_idx > threshold:
            pieces.append(_slice(last_id, word_idx + 1))
            last_id = word_idx + 1
            threshold += split_len
    # If the text ended exactly on a cut, the remainder is empty. An empty
    # example has no tokens to label and breaks the offset bookkeeping
    # downstream, so it is not emitted.
    if last_id < len(words):
        pieces.append(_slice(last_id))
    return pieces


def create_dataloader(model_nugget, text_input):
    """Build a DataLoader of realis-model inputs for ``text_input``.

    The text is tokenized with spaCy. Each word receives:
    - the event-nugget tag derived from ``get_event_nuggets(model_nugget, ...)``;
    - spaCy POS, NER (BIO), dependency-label and dependency-depth features.

    If the words tokenize to more than ``tokenizer.model_max_length`` sub-word
    ids, the sequence is split into several examples. Cuts happen only after
    "." words (see ``_split_on_sentence_ends``).

    Args:
        model_nugget: Event-nugget model, forwarded to ``get_event_nuggets``.
        text_input: The raw text to process.

    Returns:
        ``(dataloader, tokenized_dataset_ner)``:
        - ``dataloader``: a DataLoader (batch size 4) over the tokenized dataset.
        - ``tokenized_dataset_ner``: the dataset itself, in torch format and
          without the ``tokens`` column. Its rows are in the same order the
          DataLoader yields them.
    """
    event_nuggets = get_event_nuggets(model_nugget, text_input)
    doc = nlp(text_input)

    # Map `` and '' to a plain double quote and drop "$" before locating each
    # word in the raw text (presumably so get_idxs_from_text can match them).
    lookup_words = [
        tok.text.replace("``", '"').replace("''", '"').replace("$", "")
        for tok in doc
    ]
    content_idx_dict = get_idxs_from_text(text_input, lookup_words)

    words = []
    nugget_ner_tags = []
    for content_dict in content_idx_dict:
        nugget_ner_tags.append(
            get_entity_for_realis_from_idx(
                content_dict["start_idx"], content_dict["end_idx"], event_nuggets
            )
        )
        words.append(content_dict["word"])

    example = {
        "tokens": words,
        "ner_tags": nugget_ner_tags,
        "pos_spacy": [tok.pos_ for tok in doc],
        "ner_spacy": [
            tok.ent_iob_ + "-" + tok.ent_type_ if tok.ent_iob_ != "O" else tok.ent_iob_
            for tok in doc
        ],
        "dep_spacy": [tok.dep_ for tok in doc],
        "depth_spacy": [find_dep_depth(tok) for tok in doc],
    }

    content_token_len = len(
        tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"]
    )
    if content_token_len > tokenizer.model_max_length:
        data = _split_on_sentence_ends(example, content_token_len, tokenizer.model_max_length)
    else:
        data = [example]

    ner_features = Features({
        'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
        'ner_tags': Sequence(feature=ClassLabel(num_classes=len(event_nugget_list), names=event_nugget_list, names_file=None, id=None), length=-1, id=None),
        'pos_spacy': Sequence(feature=ClassLabel(num_classes=len(pos_spacy_tag_list), names=pos_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
        'ner_spacy': Sequence(feature=ClassLabel(num_classes=len(ner_spacy_tag_list), names=ner_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
        'dep_spacy': Sequence(feature=ClassLabel(num_classes=len(dep_spacy_tag_list), names=dep_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
        # 17 classes because find_dep_depth caps depths at 16.
        'depth_spacy': Sequence(feature=ClassLabel(num_classes=17, names=list(range(17)), names_file=None, id=None), length=-1, id=None),
    })

    dataset = Dataset.from_list(data, features=ner_features)
    tokenized_dataset_ner = dataset.map(
        tokenize_and_align_labels_with_pos_ner_realis,
        fn_kwargs={'tokenizer': tokenizer, 'ner_names': event_nugget_list},
        batched=True,
        load_from_cache_file=False,
    )
    tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch").remove_columns("tokens")

    batch_size = 4
    dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
    return dataloader, tokenized_dataset_ner
| |
|
def predict(dataloader, model=None):
    """Run the realis model over ``dataloader`` and return per-token class ids.

    Args:
        dataloader: Iterable of batches. Each batch is a mapping passed to the
            model as keyword arguments.
        model: Callable returning logits whose last dimension indexes the
            realis classes. When omitted, the module-level ``model_realis`` is
            used.
            NOTE(review): ``model_realis`` is not created anywhere in this
            module's visible code; confirm where it is defined.

    Returns:
        A tensor of argmax class ids with one row per example. Batches are
        concatenated along dim 0, so row ``i`` lines up with example ``i`` of
        the underlying dataset. This assumes every batch has the same padded
        sequence length (TODO confirm the tokenization pads consistently).
    """
    if model is None:
        model = model_realis  # noqa: F821 - module-level global, see docstring
    batch_predictions = []
    with torch.no_grad():
        for batch in dataloader:
            logits = model(**batch)
            batch_predictions.append(logits.argmax(-1))
    # dim=0 stacks batches along the example axis. The former dim=-1 glued
    # batches along the sequence axis: rows no longer matched dataset examples,
    # and it failed whenever the last batch was smaller.
    return torch.cat(batch_predictions, dim=0)
| |
|
def show_annotations(text_input, model_nugget=None):
    """Display the predicted event realis of ``text_input`` with Streamlit.

    The text goes through ``create_dataloader`` and ``predict``. For each
    resulting example, consecutive tokens that share a realis label are merged.
    The merged runs are rendered with ``annotated_text``:
    - "O" runs appear as plain strings;
    - other labels appear as ``(text, label)`` pairs.

    Args:
        text_input: The raw text to analyse.
        model_nugget: Event-nugget model forwarded to ``create_dataloader``.
            Required. The ``None`` default only keeps the old one-argument
            call shape valid at the signature level.

    Raises:
        TypeError: If ``model_nugget`` is not supplied.
    """
    if model_nugget is None:
        raise TypeError("show_annotations() missing required argument: 'model_nugget'")

    st.title("Event Realis")

    # create_dataloader takes (model_nugget, text_input). The previous call
    # passed only the text and therefore always raised TypeError.
    dataloader, tokenized_dataset_ner = create_dataloader(model_nugget, text_input)
    predicted_label = predict(dataloader)

    for example_idx, labels in enumerate(predicted_label):
        input_ids = tokenized_dataset_ner[example_idx]["input_ids"]
        # Drop ids 0-2. They are assumed to be the tokenizer's <s>/<pad>/</s>;
        # TODO confirm for the SecureBERT vocabulary.
        token_mask = input_ids > 2
        kept_ids = input_ids[token_mask]

        tokens = tokenizer.convert_ids_to_tokens(kept_ids, skip_special_tokens=True)
        # Strip the byte-level BPE markers "Ġ" (leading space) and "Ċ" (newline).
        # Also map "âĢĻ", the byte-level form of U+2019, to a plain apostrophe.
        tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]

        text = tokenizer.decode(kept_ids)
        spans = get_idxs_from_text(text, tokens)
        labels = labels[token_mask]

        annotated_text_list = []
        last_label = ""
        cumulative_tokens = ""
        last_end = 0

        for span, label in zip(spans, labels):
            to_label = realis_list[label]
            label_short = to_label.split("-")[1] if "-" in to_label else to_label
            if label_short == last_label:
                # Extend the run, keeping the decoded whitespace between tokens.
                cumulative_tokens += text[last_end : span["end_idx"]]
            else:
                if last_label:
                    annotated_text_list.append(
                        cumulative_tokens if last_label == "O" else (cumulative_tokens, last_label)
                    )
                last_label = label_short
                cumulative_tokens = span["word"]
            last_end = span["end_idx"]

        # Flush the final run. Skip it when the example produced no tokens:
        # previously that case emitted an empty ("", "") annotation.
        if last_label:
            annotated_text_list.append(
                cumulative_tokens if last_label == "O" else (cumulative_tokens, last_label)
            )
        annotated_text(annotated_text_list)
| |
|
def get_event_realis(text_input, model_nugget=None):
    """Extract predicted event-realis spans from ``text_input``.

    Args:
        text_input: The raw text to analyse.
        model_nugget: Event-nugget model forwarded to ``create_dataloader``.
            Required. The ``None`` default only keeps the old one-argument
            call shape valid at the signature level.

    Returns:
        A list of dicts, one per maximal run of consecutive tokens that share a
        realis label other than "O". Keys:
        - ``startOffset`` / ``endOffset``: character offsets into ``text_input``;
        - ``realis``: the label;
        - ``text``: ``text_input[startOffset:endOffset]``.

    Raises:
        TypeError: If ``model_nugget`` is not supplied.
    """
    if model_nugget is None:
        raise TypeError("get_event_realis() missing required argument: 'model_nugget'")

    # create_dataloader takes (model_nugget, text_input). The previous call
    # passed only the text and therefore always raised TypeError.
    dataloader, tokenized_dataset_ner = create_dataloader(model_nugget, text_input)
    predicted_label = predict(dataloader)

    predicted_event_realis = []
    # Absolute offset in text_input where the current example's text begins.
    text_length = 0

    def _emit(start_idx, end_idx, label):
        # Record a finished run (offsets relative to text_length), unless it is
        # empty or carries no realis ("" before the first token, or "O").
        if end_idx > start_idx and label not in ("", "O"):
            predicted_event_realis.append(
                {
                    "startOffset": text_length + start_idx,
                    "endOffset": text_length + end_idx,
                    "realis": label,
                    "text": text_input[text_length + start_idx : text_length + end_idx],
                }
            )

    for example_idx, labels in enumerate(predicted_label):
        input_ids = tokenized_dataset_ner[example_idx]["input_ids"]
        # Drop ids 0-2. They are assumed to be the tokenizer's <s>/<pad>/</s>;
        # TODO confirm for the SecureBERT vocabulary.
        token_mask = input_ids > 2
        kept_ids = input_ids[token_mask]

        tokens = tokenizer.convert_ids_to_tokens(kept_ids, skip_special_tokens=True)
        # Strip the byte-level BPE markers "Ġ" (leading space) and "Ċ" (newline).
        # Also map "âĢĻ", the byte-level form of U+2019, to a plain apostrophe.
        tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]

        # Offsets returned here are relative to text_input[text_length:].
        spans = get_idxs_from_text(text_input[text_length:], tokens)
        labels = labels[token_mask]

        start_idx = end_idx = 0
        last_label = ""
        last_span = None

        for span, label in zip(spans, labels):
            current_label = realis_list[label]
            if current_label == last_label:
                end_idx = span["end_idx"]
            else:
                _emit(start_idx, end_idx, last_label)
                start_idx = span["start_idx"]
                end_idx = span["start_idx"] + len(span["word"])
                last_label = current_label
            last_span = span

        # Flush the run still open at the end of this example. It was
        # previously dropped, losing any span that reached the example's end.
        _emit(start_idx, end_idx, last_label)

        # Advance past this example's last token. An example with no tokens
        # previously crashed here, subscripting the integer loop index.
        if last_span is not None:
            text_length += last_span["end_idx"]

    return predicted_event_realis