# Hugging Face Space status (scraped page header): Sleeping
import os
import tempfile
from io import StringIO

import gradio as gr
import pandas as pd
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertForSequenceClassification,
)
# Hugging Face checkpoint for the DNABERT-2 backbone.
MODEL_NAME = "zhihan1996/DNABERT-2-117M"

# Mutation classes (example mapping — update based on your fine-tuning).
# The classification head below is sized from this mapping, so every key
# here is a class index the model can actually emit.
mutation_map = {
    0: "No Mutation",
    1: "SNV",
    2: "Insertion",
    3: "Deletion",
}

# Load the DNABERT-2 tokenizer and model.
# DNABERT-2 publishes its modified BERT architecture as remote code, so the
# Auto class is loaded with trust_remote_code=True (as for the tokenizer).
# num_labels must match mutation_map; without it the head defaults to two
# classes and labels 2/3 could never be predicted.
# NOTE(review): the base checkpoint presumably has no trained classification
# head, so those weights are freshly initialised here and predictions are
# placeholders until a fine-tuned checkpoint is used — confirm.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    num_labels=len(mutation_map),
)
model.eval()  # inference only: make sure dropout is disabled
# Column layout of every table returned by analyze_sequences.
_RESULT_COLUMNS = ["Sequence", "Predicted Mutation", "Confidence Score"]


def _classify_sequence(seq):
    """Classify one DNA sequence; return (mutation label, softmax confidence)."""
    inputs = tokenizer(seq, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.softmax(logits, dim=1)[0]
    predicted_class = int(torch.argmax(probabilities).item())
    # Indices without an entry in mutation_map are reported as "Unknown".
    label = mutation_map.get(predicted_class, "Unknown")
    return label, float(probabilities[predicted_class].item())


# Runs the DNABERT classifier on every sequence of an input table.
def analyze_sequences(input_df):
    """Predict a mutation class for each row of ``input_df``.

    Args:
        input_df: DataFrame with a ``DNA_Sequence`` column, or None.

    Returns:
        DataFrame with columns Sequence, Predicted Mutation and Confidence
        Score (one row per non-blank sequence). It is empty, but keeps those
        columns, when there is no input or no usable sequence.

    Raises:
        ValueError: if a non-empty ``input_df`` lacks a ``DNA_Sequence`` column.
    """
    if input_df is None or input_df.empty:
        return pd.DataFrame(columns=_RESULT_COLUMNS)
    if "DNA_Sequence" not in input_df.columns:
        raise ValueError(
            "Input CSV must contain a 'DNA_Sequence' column; "
            f"found columns: {list(input_df.columns)}"
        )

    results = []
    for raw in input_df["DNA_Sequence"]:
        # Empty CSV cells arrive as NaN/None; the tokenizer cannot handle them.
        if pd.isna(raw):
            continue
        seq = str(raw).strip()
        if not seq:
            continue
        mutation, confidence = _classify_sequence(seq)
        results.append({
            "Sequence": seq,
            "Predicted Mutation": mutation,
            "Confidence Score": confidence,
        })
    return pd.DataFrame(results, columns=_RESULT_COLUMNS)
# Built-in demo: analyse a handful of hard-coded sequences.
def load_example_data():
    """Return the analysis table for four built-in example DNA sequences.

    The sequences are wrapped in a one-column ``DNA_Sequence`` DataFrame and
    passed straight to ``analyze_sequences``; its result is returned as-is.
    """
    sample_sequences = [
        "AGCTAGCTA",
        "GATCGATCG",
        "TTAGCTAGCT",
        "ATGCGTAGC",
    ]
    sample_table = pd.DataFrame({"DNA_Sequence": sample_sequences})
    return analyze_sequences(sample_table)
# Converts a DataFrame to CSV text.
def dataframe_to_csv(df):
    """Serialize ``df`` to CSV text without the index column.

    Args:
        df: DataFrame to serialize, or None.

    Returns:
        The CSV as a string, or "" when ``df`` is None or empty.
    """
    if df is None or df.empty:
        return ""
    # With no target path, DataFrame.to_csv returns the CSV as a str, so no
    # StringIO buffer (which was never imported here) is needed.
    return df.to_csv(index=False)
def _mutation_counts(result_df):
    """Return a (Predicted Mutation, count) frame, most frequent label first."""
    # Name both columns explicitly: the default names produced by
    # value_counts().reset_index() differ between pandas 1.x and 2.x.
    return (
        result_df["Predicted Mutation"]
        .value_counts()
        .rename_axis("Predicted Mutation")
        .reset_index(name="count")
    )


def _summary_text(counts):
    """Render a counts frame from _mutation_counts as a bullet list."""
    lines = ["Mutation Statistics:"]
    lines += [
        f"- {mutation}: {count}"
        for mutation, count in zip(counts["Predicted Mutation"], counts["count"])
    ]
    return "\n".join(lines) + "\n"


# Generate mutation statistics summary and chart.
def get_mutation_stats(result_df):
    """Summarise predicted mutation frequencies as text and a bar chart.

    Args:
        result_df: Output of analyze_sequences (needs a "Predicted Mutation"
            column), or None.

    Returns:
        (summary_text, chart): chart is a gr.BarPlot of counts per mutation
        type. Returns ("No data available.", None) for None/empty input.
    """
    if result_df is None or result_df.empty:
        return "No data available.", None

    counts = _mutation_counts(result_df)
    chart = gr.BarPlot(
        counts,
        x="Predicted Mutation",
        y="count",
        title="Mutation Frequency",
        color="Predicted Mutation",
        tooltip=["Predicted Mutation", "count"],
        vertical=False,
        height=200,
    )
    return _summary_text(counts), chart
def _read_uploaded_csv(file):
    """Load a Gradio upload value into a DataFrame.

    Accepts a filepath (str / os.PathLike, as passed by Gradio 4+), an object
    exposing the temp-file path as ``.name`` (Gradio 3.x), or an existing
    DataFrame, which is returned unchanged.
    """
    if isinstance(file, pd.DataFrame):
        return file
    if isinstance(file, (str, os.PathLike)):
        path = file
    else:
        path = getattr(file, "name", file)
    return pd.read_csv(path)


# Unified function to process and return all outputs.
def process_and_get_stats(file=None):
    """Analyse an uploaded CSV (or the built-in examples) and summarise it.

    Args:
        file: Value from the gr.File upload component, or None to analyse the
            built-in example data.

    Returns:
        (result_df, summary_text, chart) for the table, textbox and chart
        outputs.
    """
    if file is not None:
        # gr.File yields a path / file wrapper, not a DataFrame: read it first.
        result_df = analyze_sequences(_read_uploaded_csv(file))
    else:
        result_df = load_example_data()
    summary, chart = get_mutation_stats(result_df)
    return result_df, summary, chart
def _export_results_csv(result_df):
    """Write the analysis table to a temporary CSV file for download.

    Returns the file path, or None (which clears the download box) when
    there is nothing to export.
    """
    csv_text = dataframe_to_csv(result_df)
    if not csv_text:
        return None
    # delete=False: Gradio serves the file after this function returns.
    # newline="": keep to_csv's own line endings untranslated.
    with tempfile.NamedTemporaryFile(
        "w",
        suffix=".csv",
        prefix="mutatex_results_",
        delete=False,
        encoding="utf-8",
        newline="",
    ) as handle:
        handle.write(csv_text)
    return handle.name


# Gradio Interface
with gr.Blocks(theme="default") as demo:
    gr.Markdown(
        """
# 🧬 MutateX — Liquid Biopsy Mutation Detection Tool
Upload a CSV file with a `DNA_Sequence` column to simulate mutation detection.
*Developed by [GradSyntax](https://www.gradsyntax.com)*
"""
    )
    with gr.Row(equal_height=True):
        upload_btn = gr.File(label="Upload CSV File", file_types=[".csv"])
        example_btn = gr.Button("🧪 Load Example Data")
    output_table = gr.DataFrame(
        label="Analysis Results",
        headers=["Sequence", "Predicted Mutation", "Confidence Score"],
    )
    stats_text = gr.Textbox(label="Mutation Statistics Summary")
    # get_mutation_stats returns a gr.BarPlot, so this output slot must be a
    # BarPlot as well (gr.Plot expects a matplotlib/plotly/altair figure).
    stats_chart = gr.BarPlot(label="Mutation Frequency Chart")
    # Output-only: filled with a CSV path after each analysis.
    download_btn = gr.File(label="⬇️ Download Results as CSV", interactive=False)

    # Function calls: analyse, then export the resulting table for download.
    analysis_outputs = [output_table, stats_text, stats_chart]
    upload_btn.upload(
        fn=process_and_get_stats, inputs=upload_btn, outputs=analysis_outputs
    ).then(fn=_export_results_csv, inputs=output_table, outputs=download_btn)
    example_btn.click(
        fn=process_and_get_stats, inputs=None, outputs=analysis_outputs
    ).then(fn=_export_results_csv, inputs=output_table, outputs=download_btn)

if __name__ == "__main__":
    demo.launch()