Spaces:
Runtime error
Runtime error
File size: 5,620 Bytes
f825a37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
"""
Module de génération de réponses pour le service client Amazon
Utilise CroissantLLMChat - Modèle bilingue français-anglais optimisé
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Modèle bilingue français-anglais spécialement conçu pour le français
MODEL_NAME = "croissantllm/CroissantLLMChat-v0.1"
# Variables globales pour le modèle
model = None
tokenizer = None
def load_model():
"""
Charge le modèle CroissantLLMChat et son tokenizer
CroissantLLM est un modèle 1.3B VRAIMENT bilingue (50% FR / 50% EN)
Returns:
tuple: (model, tokenizer) chargés
"""
global model, tokenizer
print(f"🔄 Chargement du modèle {MODEL_NAME}...")
print("⏳ CroissantLLM est un modèle français de 1.3B paramètres (~2-3 GB)")
# Charger le tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Charger le modèle en float32 pour CPU
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float32,
device_map="cpu",
low_cpu_mem_usage=True
)
print("✅ Modèle CroissantLLMChat chargé avec succès !")
print("🥐 Modèle français bilingue prêt !")
return model, tokenizer
def build_chat_messages(review_text: str) -> list:
"""
Construit les messages pour CroissantLLMChat
Format officiel avec apply_chat_template
Args:
review_text (str): Texte de l'avis client négatif
Returns:
list: Messages formatés pour CroissantLLMChat
"""
# CroissantLLMChat utilise un format chat officiel
# Avec un message utilisateur clair
chat_messages = [
{
"role": "user",
"content": f"""Tu es un agent du service client Amazon. Réponds en français à cet avis négatif avec empathie et professionnalisme :
"{review_text}"
Réponds en présentant des excuses, en reconnaissant le problème, et en proposant une solution concrète (remboursement ou échange)."""
}
]
return chat_messages
def generer_reponse(review_text: str, max_tokens: int = 120, temperature: float = 0.7) -> str:
"""
Génère une réponse au service client pour un avis négatif
Utilise CroissantLLMChat avec apply_chat_template (méthode officielle)
Args:
review_text (str): Texte de l'avis client négatif
max_tokens (int): Nombre maximum de tokens à générer
temperature (float): Température de génération (0.7 = équilibré)
Returns:
str: Réponse générée par le modèle EN FRANÇAIS
"""
global model, tokenizer
# Charger le modèle si pas encore fait
if model is None or tokenizer is None:
load_model()
# Construire les messages au format chat
chat_messages = build_chat_messages(review_text)
# Appliquer le template officiel de CroissantLLMChat
# C'est la méthode recommandée dans la documentation
chat_input = tokenizer.apply_chat_template(
chat_messages,
tokenize=False,
add_generation_prompt=True
)
# Tokeniser le chat formaté
inputs = tokenizer(
chat_input,
return_tensors="pt",
max_length=512,
truncation=True
)
# Générer avec CroissantLLMChat
# Température 0.7 recommandée (doc dit 0.3+ minimum)
with torch.no_grad():
outputs = model.generate(
inputs.input_ids,
attention_mask=inputs.attention_mask,
max_new_tokens=max_tokens,
temperature=temperature,
do_sample=True,
top_p=0.9,
top_k=50,
repetition_penalty=1.2,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id
)
# MÉTHODE AMÉLIORÉE : Décoder UNIQUEMENT les nouveaux tokens
# On ne décode PAS le prompt d'entrée
input_length = inputs.input_ids.shape[1]
generated_tokens = outputs[0][input_length:] # Prendre uniquement les tokens générés
# Décoder uniquement la réponse générée
answer = tokenizer.decode(generated_tokens, skip_special_tokens=True)
# Nettoyer les tokens spéciaux qui pourraient rester
special_tokens = ["<|im_start|>", "<|im_end|>", "assistant", "user", "system"]
for token in special_tokens:
answer = answer.replace(token, "")
# Nettoyer espaces multiples
answer = ' '.join(answer.split())
answer = answer.strip()
# Limiter à 3-4 phrases maximum
sentences = answer.split('.')
clean_sentences = [s.strip() for s in sentences if s.strip()]
if len(clean_sentences) > 4:
answer = '. '.join(clean_sentences[:4]) + '.'
else:
answer = '. '.join(clean_sentences)
if not answer.endswith('.'):
answer += '.'
return answer
# Test du module
if __name__ == "__main__":
print("🧪 Test du module de génération avec CroissantLLMChat\n")
# Charger le modèle
load_model()
# Test 1
avis_test_1 = "Le produit est arrivé cassé et le service client ne répond pas. Très déçu !"
print(f"📝 Avis test 1: {avis_test_1}")
reponse_1 = generer_reponse(avis_test_1)
print(f"💬 Réponse: {reponse_1}\n")
# Test 2
avis_test_2 = "Livraison en retard de 2 semaines, produit endommagé."
print(f"📝 Avis test 2: {avis_test_2}")
reponse_2 = generer_reponse(avis_test_2)
print(f"💬 Réponse: {reponse_2}\n")
print("✅ Tests terminés !")
|