amazon-sentiment-analysis-shirin / generate_response.py
Oxyb50410's picture
Upload 4 files
f825a37 verified
"""
Module de génération de réponses pour le service client Amazon
Utilise CroissantLLMChat - Modèle bilingue français-anglais optimisé
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Modèle bilingue français-anglais spécialement conçu pour le français
MODEL_NAME = "croissantllm/CroissantLLMChat-v0.1"
# Variables globales pour le modèle
model = None
tokenizer = None
def load_model():
"""
Charge le modèle CroissantLLMChat et son tokenizer
CroissantLLM est un modèle 1.3B VRAIMENT bilingue (50% FR / 50% EN)
Returns:
tuple: (model, tokenizer) chargés
"""
global model, tokenizer
print(f"🔄 Chargement du modèle {MODEL_NAME}...")
print("⏳ CroissantLLM est un modèle français de 1.3B paramètres (~2-3 GB)")
# Charger le tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Charger le modèle en float32 pour CPU
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float32,
device_map="cpu",
low_cpu_mem_usage=True
)
print("✅ Modèle CroissantLLMChat chargé avec succès !")
print("🥐 Modèle français bilingue prêt !")
return model, tokenizer
def build_chat_messages(review_text: str) -> list:
"""
Construit les messages pour CroissantLLMChat
Format officiel avec apply_chat_template
Args:
review_text (str): Texte de l'avis client négatif
Returns:
list: Messages formatés pour CroissantLLMChat
"""
# CroissantLLMChat utilise un format chat officiel
# Avec un message utilisateur clair
chat_messages = [
{
"role": "user",
"content": f"""Tu es un agent du service client Amazon. Réponds en français à cet avis négatif avec empathie et professionnalisme :
"{review_text}"
Réponds en présentant des excuses, en reconnaissant le problème, et en proposant une solution concrète (remboursement ou échange)."""
}
]
return chat_messages
def generer_reponse(review_text: str, max_tokens: int = 120, temperature: float = 0.7) -> str:
"""
Génère une réponse au service client pour un avis négatif
Utilise CroissantLLMChat avec apply_chat_template (méthode officielle)
Args:
review_text (str): Texte de l'avis client négatif
max_tokens (int): Nombre maximum de tokens à générer
temperature (float): Température de génération (0.7 = équilibré)
Returns:
str: Réponse générée par le modèle EN FRANÇAIS
"""
global model, tokenizer
# Charger le modèle si pas encore fait
if model is None or tokenizer is None:
load_model()
# Construire les messages au format chat
chat_messages = build_chat_messages(review_text)
# Appliquer le template officiel de CroissantLLMChat
# C'est la méthode recommandée dans la documentation
chat_input = tokenizer.apply_chat_template(
chat_messages,
tokenize=False,
add_generation_prompt=True
)
# Tokeniser le chat formaté
inputs = tokenizer(
chat_input,
return_tensors="pt",
max_length=512,
truncation=True
)
# Générer avec CroissantLLMChat
# Température 0.7 recommandée (doc dit 0.3+ minimum)
with torch.no_grad():
outputs = model.generate(
inputs.input_ids,
attention_mask=inputs.attention_mask,
max_new_tokens=max_tokens,
temperature=temperature,
do_sample=True,
top_p=0.9,
top_k=50,
repetition_penalty=1.2,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id
)
# MÉTHODE AMÉLIORÉE : Décoder UNIQUEMENT les nouveaux tokens
# On ne décode PAS le prompt d'entrée
input_length = inputs.input_ids.shape[1]
generated_tokens = outputs[0][input_length:] # Prendre uniquement les tokens générés
# Décoder uniquement la réponse générée
answer = tokenizer.decode(generated_tokens, skip_special_tokens=True)
# Nettoyer les tokens spéciaux qui pourraient rester
special_tokens = ["<|im_start|>", "<|im_end|>", "assistant", "user", "system"]
for token in special_tokens:
answer = answer.replace(token, "")
# Nettoyer espaces multiples
answer = ' '.join(answer.split())
answer = answer.strip()
# Limiter à 3-4 phrases maximum
sentences = answer.split('.')
clean_sentences = [s.strip() for s in sentences if s.strip()]
if len(clean_sentences) > 4:
answer = '. '.join(clean_sentences[:4]) + '.'
else:
answer = '. '.join(clean_sentences)
if not answer.endswith('.'):
answer += '.'
return answer
# Test du module
if __name__ == "__main__":
print("🧪 Test du module de génération avec CroissantLLMChat\n")
# Charger le modèle
load_model()
# Test 1
avis_test_1 = "Le produit est arrivé cassé et le service client ne répond pas. Très déçu !"
print(f"📝 Avis test 1: {avis_test_1}")
reponse_1 = generer_reponse(avis_test_1)
print(f"💬 Réponse: {reponse_1}\n")
# Test 2
avis_test_2 = "Livraison en retard de 2 semaines, produit endommagé."
print(f"📝 Avis test 2: {avis_test_2}")
reponse_2 = generer_reponse(avis_test_2)
print(f"💬 Réponse: {reponse_2}\n")
print("✅ Tests terminés !")