Update app.py
app.py CHANGED
@@ -305,86 +305,27 @@ def query(input: str, total_examples: int, near_far_ratio: float = 0.5):
 # Example
 print("Example: ", query("water bottle", total_examples=4, near_far_ratio=0.5))
 
-# 1. Load dataset
-dataset = load_dataset("cwinkler/patents_green_plastics")
-
-# Split into train/test
-dataset = dataset["train"].train_test_split(test_size=0.2, seed=42)
-train_dataset = dataset["train"]
-test_dataset = dataset["test"]
-
-# 2. Tokenizer
-model_name = "distilbert-base-uncased"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-def preprocess(examples):
-    return tokenizer(examples["abstract"], truncation=True, padding="max_length", max_length=128)
-
-tokenized = dataset.map(preprocess, batched=True)
-tokenized = tokenized.rename_column("label", "labels")
-tokenized.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
-
-train_dataset = tokenized["train"].shuffle(seed=42).select(range(2000))  # subset for CPU
-test_dataset = tokenized["test"]
-
-# 3. Base model
-base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
-
-# 4. LoRA config
-lora_config = LoraConfig(
-    task_type=TaskType.SEQ_CLS,
-    r=8, lora_alpha=16, lora_dropout=0.1, target_modules=["q_lin", "v_lin"]
-)
-
-model = get_peft_model(base_model, lora_config)
-
-# 5. Training setup
-import os
-os.environ["WANDB_DISABLED"] = "true"
-
-args = TrainingArguments(
-    output_dir="./lora-green-patents",
-    do_eval=True,    # instead of evaluation_strategy
-    eval_steps=500,  # run eval every N steps
-    save_steps=500,  # save checkpoint every N steps
-    learning_rate=2e-4,
-    per_device_train_batch_size=8,
-    per_device_eval_batch_size=8,
-    num_train_epochs=10,
-    logging_steps=20,
-    report_to=None
-)
-
-trainer = Trainer(
-    model=model,
-    args=args,
-    train_dataset=train_dataset,
-    eval_dataset=test_dataset,
-)
-
-# 6. Train
-trainer.train()
-
-# 7. Save adapter
-model.save_pretrained("lora-green-patents")
-tokenizer.save_pretrained("lora-green-patents")
-
-# 8. Inference
-
 # Load base + adapter
-
-texts = [
-    "A biodegradable plastic composition derived from renewable corn starch.",
-    "A new synthetic polymer with enhanced tensile strength."
-    "Refreshing Taste: Every bottle of Pure Life Water is enhanced with minerals for a crisp taste that makes drinking water delicious. 12 pack of 16.9 fl oz water bottles."
-    "This 18/8 stainless steel water bottle is designed to last a lifetime. Plastic free & Eco friendly water bottles are a healthier option for you & the planet! However, Water in stainless steel tastes different than plastic, make sure your taste buds are ready for this healthy switch"
-]
-print(clf(texts))
+def lora_load():
+    model_name = "distilbert-base-uncased"  # same base you trained on
+
+    tokenizer = AutoTokenizer.from_pretrained(REPO_ID_LORA_GREEN_PATENTS)  # , token=token)
+    base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)  # , token=token)
+    model = PeftModel.from_pretrained(base_model, REPO_ID_LORA_GREEN_PATENTS)  # , token=token)
+
+    clf = pipeline("text-classification", model=model, tokenizer=tokenizer)
+
+    # Examples of patents and products (fixed commas)
+    texts = [
+        "A biodegradable plastic composition derived from renewable corn starch.",
+        "A new synthetic polymer with enhanced tensile strength.",
+        "Refreshing Taste: Every bottle of Pure Life Water is enhanced with minerals for a crisp taste that makes drinking water delicious. 12 pack of 16.9 fl oz water bottles.",
+        "This 18/8 stainless steel water bottle is designed to last a lifetime. Plastic free & Eco friendly water bottles are a healthier option for you & the planet! However, Water in stainless steel tastes different than plastic, make sure your taste buds are ready for this healthy switch"
+    ]
+    print(clf(texts))
+    return clf
+
+clf = lora_load()
 
 ex_waterbottle_text = [
     "A single use case made with fossil fuels and gasoline.",