import os
from tqdm import tqdm
import torch
import json
import time
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score
from networks import create_architecture, count_parameters
from utils.dataset import create_dataloader
from utils.processing import add_processing_arguments
from parser import get_parser


def test(loader, model, settings, device):
    model.eval()
    start_time = time.time()

    # Output file paths
    output_dir = f'./results/{settings.name}/{settings.data_keys}/data/'
    os.makedirs(output_dir, exist_ok=True)
    csv_filename = os.path.join(output_dir, 'results.csv')
    metrics_filename = os.path.join(output_dir, 'metrics.json')
    image_results_filename = os.path.join(output_dir, 'image_results.json')
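    # Resulting layout under output_dir (derived directly from the paths above):
    #   results.csv         per-image scores in CSV form
    #   metrics.json        aggregate metrics
    #   image_results.json  per-image records as JSON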

    # Accumulators for all results
    all_scores = []
    all_labels = []
    all_paths = []
    image_results = []

    # Extract the training dataset keys from the model name
    # (format: "training_keys_freeze_down" or "training_keys")
    model_name = settings.name

    # Strip common suffixes such as "_freeze_down"
    if '_freeze_down' in model_name:
        training_name = model_name.replace('_freeze_down', '')
    else:
        training_name = model_name

    # Split on '&' to get the individual training dataset keys
    if '&' in training_name:
        training_dataset_keys = training_name.split('&')
    else:
        training_dataset_keys = [training_name]
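    # Worked example (hypothetical dataset keys, purely illustrative):
    #   settings.name = 'progan&latent_diffusion_freeze_down'
    #     -> training_name = 'progan&latent_diffusion'
    #     -> training_dataset_keys = ['progan', 'latent_diffusion']
    # Note: the keys are parsed here for reference but not used again in this function.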

    # Write the CSV header
    with open(csv_filename, 'w') as f:
        f.write(f"{','.join(['name', 'pro', 'flag'])}\n")

    with torch.no_grad():
        with tqdm(loader, unit='batch', mininterval=0.5) as tbatch:
            tbatch.set_description('Validation')
            for data_dict in tbatch:
                data = data_dict['img'].to(device)
                labels = data_dict['target'].to(device)
                paths = data_dict['path']
                scores = model(data).squeeze(1)

                # Collect results
                for score, label, path in zip(scores, labels, paths):
                    score_val = score.item()
                    label_val = label.item()
                    all_scores.append(score_val)
                    all_labels.append(label_val)
                    all_paths.append(path)
                    image_results.append({
                        'path': path,
                        'score': score_val,
                        'label': label_val,
                    })

                # Append to the CSV (kept for backward compatibility); no spaces
                # after the commas, so rows line up with the 'name,pro,flag' header
                with open(csv_filename, 'a') as f:
                    for score, label, path in zip(scores, labels, paths):
                        f.write(f"{path},{score.item()},{label.item()}\n")

    # Compute metrics
    all_scores = np.array(all_scores)
    all_labels = np.array(all_labels)

    # Convert raw scores to hard predictions (threshold at 0, as in train.py)
    predictions = (all_scores > 0).astype(int)
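    # The model outputs raw logits, so thresholding at 0 is equivalent to
    # thresholding sigmoid(score) at 0.5. For example:
    #   score =  1.3 -> sigmoid ~ 0.79 -> predicted fake (1)
    #   score = -0.4 -> sigmoid ~ 0.40 -> predicted real (0)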

    # Overall accuracy
    total_accuracy = accuracy_score(all_labels, predictions)

    # TPR (true positive rate) = TP / (TP + FN) = accuracy on fake images (label == 1)
    fake_mask = all_labels == 1
    if fake_mask.sum() > 0:
        tpr = accuracy_score(all_labels[fake_mask], predictions[fake_mask])
    else:
        tpr = 0.0

    # TNR (true negative rate) = TN / (TN + FP) = accuracy on real images (label == 0)
    real_mask = all_labels == 0
    if real_mask.sum() > 0:
        tnr = accuracy_score(all_labels[real_mask], predictions[real_mask])
    else:
        tnr = 0.0

    # AUC is rank-based, so the raw scores would give the same value; the sigmoid
    # (a monotonic map to [0, 1]) is applied only so the inputs read as probabilities
    if len(np.unique(all_labels)) > 1:
        probabilities = torch.sigmoid(torch.tensor(all_scores)).numpy()
        auc = roc_auc_score(all_labels, probabilities)
    else:
        # AUC is undefined when only one class is present
        auc = 0.0

    execution_time = time.time() - start_time

    # Aggregate metrics
    metrics = {
        'TPR': float(tpr),
        'TNR': float(tnr),
        'Acc total': float(total_accuracy),
        'AUC': float(auc),
        'execution time': float(execution_time),
    }

    # Write the metrics JSON
    with open(metrics_filename, 'w') as f:
        json.dump(metrics, f, indent=2)
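    # metrics.json then looks like this (values are illustrative):
    #   {
    #     "TPR": 0.9421,
    #     "TNR": 0.9015,
    #     "Acc total": 0.9218,
    #     "AUC": 0.9731,
    #     "execution time": 84.27
    #   }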

    # Write the per-image results JSON
    with open(image_results_filename, 'w') as f:
        json.dump(image_results, f, indent=2)

    print(f'\nMetrics saved to {metrics_filename}')
    print(f'Image results saved to {image_results_filename}')
    print('\nMetrics:')
    print(f'  TPR: {tpr:.4f}')
    print(f'  TNR: {tnr:.4f}')
    print(f'  Accuracy: {total_accuracy:.4f}')
    print(f'  AUC: {auc:.4f}')
    print(f'  Execution time: {execution_time:.2f} seconds')


if __name__ == '__main__':
    parser = get_parser()
    parser = add_processing_arguments(parser)
    settings = parser.parse_args()

    device = torch.device(settings.device if torch.cuda.is_available() else 'cpu')
    test_dataloader = create_dataloader(settings, split='test')

    model = create_architecture(settings.arch, pretrained=True, num_classes=1).to(device)
    num_parameters = count_parameters(model)
    print(f'Arch: {settings.arch} with #parameters {num_parameters}')

    load_path = f'./checkpoint/{settings.name}/weights/best.pt'
    print(f'Loading the model from {load_path}')
    model.load_state_dict(torch.load(load_path, map_location=device)['model'])

    test(test_dataloader, model, settings, device)
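
# Example invocation (a sketch: the exact flags come from get_parser() and
# add_processing_arguments(), which are not shown here; the attribute names
# used above suggest options like --name, --arch, --data_keys, and --device):
#   python test.py --name progan_freeze_down --arch resnet50 --data_keys coco --device cuda:0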