Update app.py
app.py CHANGED
@@ -220,12 +220,12 @@ import tensorflow as tf
 import praw
 import os
 
-
+
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch
 from scipy.special import softmax
 
-
+
 model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
 tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
 
@@ -235,7 +235,7 @@ LABELS = {
     2: "Negative"
 }
 
-
+
 fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
 fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
 fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
@@ -254,7 +254,7 @@ def fetch_reddit_text(reddit_url):
     except Exception as e:
         return f"Error fetching Reddit post: {str(e)}"
 
-
+
 def fallback_classifier(text):
     encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
     with torch.no_grad():
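
The last hunk stops inside fallback_classifier, right after the torch.no_grad() context is opened. Below is a minimal sketch of how such a fallback path typically completes: run the RoBERTa model, apply scipy's softmax to the logits, and map the argmax index to a label. The FALLBACK_LABELS mapping and the return format are assumptions, not taken from this commit; the 0/1/2 = negative/neutral/positive order matches the cardiffnlp/twitter-roberta-base-sentiment model card.

# Sketch only; label mapping and return value are assumptions, not from this commit.
import numpy as np
import torch
from scipy.special import softmax
from transformers import AutoTokenizer, AutoModelForSequenceClassification

fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)

# Assumed label order for the cardiffnlp sentiment model: 0=negative, 1=neutral, 2=positive.
FALLBACK_LABELS = {0: "Negative", 1: "Neutral", 2: "Positive"}

def fallback_classifier(text):
    # Tokenize as in the diff: PyTorch tensors, truncation and padding enabled.
    encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        output = fallback_model(**encoded_input)
    # Softmax over the logits gives class probabilities; argmax picks the predicted class.
    scores = softmax(output.logits.numpy()[0])
    return FALLBACK_LABELS[int(np.argmax(scores))]

Used this way, the function can serve as a drop-in backup whenever the primary TFBertForSequenceClassification model fails to load or errors at inference time.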