File size: 5,620 Bytes
f825a37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
Module de génération de réponses pour le service client Amazon
Utilise CroissantLLMChat - Modèle bilingue français-anglais optimisé
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Modèle bilingue français-anglais spécialement conçu pour le français
MODEL_NAME = "croissantllm/CroissantLLMChat-v0.1"

# Variables globales pour le modèle
model = None
tokenizer = None

def load_model():
    """
    Charge le modèle CroissantLLMChat et son tokenizer
    CroissantLLM est un modèle 1.3B VRAIMENT bilingue (50% FR / 50% EN)
    
    Returns:
        tuple: (model, tokenizer) chargés
    """
    global model, tokenizer
    
    print(f"🔄 Chargement du modèle {MODEL_NAME}...")
    print("⏳ CroissantLLM est un modèle français de 1.3B paramètres (~2-3 GB)")
    
    # Charger le tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    
    # Charger le modèle en float32 pour CPU
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True
    )
    
    print("✅ Modèle CroissantLLMChat chargé avec succès !")
    print("🥐 Modèle français bilingue prêt !")
    
    return model, tokenizer

def build_chat_messages(review_text: str) -> list:
    """
    Construit les messages pour CroissantLLMChat
    Format officiel avec apply_chat_template
    
    Args:
        review_text (str): Texte de l'avis client négatif
        
    Returns:
        list: Messages formatés pour CroissantLLMChat
    """
    # CroissantLLMChat utilise un format chat officiel
    # Avec un message utilisateur clair
    chat_messages = [
        {
            "role": "user", 
            "content": f"""Tu es un agent du service client Amazon. Réponds en français à cet avis négatif avec empathie et professionnalisme :

"{review_text}"

Réponds en présentant des excuses, en reconnaissant le problème, et en proposant une solution concrète (remboursement ou échange)."""
        }
    ]
    
    return chat_messages

def generer_reponse(review_text: str, max_tokens: int = 120, temperature: float = 0.7) -> str:
    """
    Génère une réponse au service client pour un avis négatif
    Utilise CroissantLLMChat avec apply_chat_template (méthode officielle)
    
    Args:
        review_text (str): Texte de l'avis client négatif
        max_tokens (int): Nombre maximum de tokens à générer
        temperature (float): Température de génération (0.7 = équilibré)
        
    Returns:
        str: Réponse générée par le modèle EN FRANÇAIS
    """
    global model, tokenizer
    
    # Charger le modèle si pas encore fait
    if model is None or tokenizer is None:
        load_model()
    
    # Construire les messages au format chat
    chat_messages = build_chat_messages(review_text)
    
    # Appliquer le template officiel de CroissantLLMChat
    # C'est la méthode recommandée dans la documentation
    chat_input = tokenizer.apply_chat_template(
        chat_messages, 
        tokenize=False, 
        add_generation_prompt=True
    )
    
    # Tokeniser le chat formaté
    inputs = tokenizer(
        chat_input,
        return_tensors="pt",
        max_length=512,
        truncation=True
    )
    
    # Générer avec CroissantLLMChat
    # Température 0.7 recommandée (doc dit 0.3+ minimum)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    
    # MÉTHODE AMÉLIORÉE : Décoder UNIQUEMENT les nouveaux tokens
    # On ne décode PAS le prompt d'entrée
    input_length = inputs.input_ids.shape[1]
    generated_tokens = outputs[0][input_length:]  # Prendre uniquement les tokens générés
    
    # Décoder uniquement la réponse générée
    answer = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    
    # Nettoyer les tokens spéciaux qui pourraient rester
    special_tokens = ["<|im_start|>", "<|im_end|>", "assistant", "user", "system"]
    for token in special_tokens:
        answer = answer.replace(token, "")
    
    # Nettoyer espaces multiples
    answer = ' '.join(answer.split())
    answer = answer.strip()
    
    # Limiter à 3-4 phrases maximum
    sentences = answer.split('.')
    clean_sentences = [s.strip() for s in sentences if s.strip()]
    if len(clean_sentences) > 4:
        answer = '. '.join(clean_sentences[:4]) + '.'
    else:
        answer = '. '.join(clean_sentences)
        if not answer.endswith('.'):
            answer += '.'
    
    return answer

# Test du module
if __name__ == "__main__":
    print("🧪 Test du module de génération avec CroissantLLMChat\n")
    
    # Charger le modèle
    load_model()
    
    # Test 1
    avis_test_1 = "Le produit est arrivé cassé et le service client ne répond pas. Très déçu !"
    print(f"📝 Avis test 1: {avis_test_1}")
    reponse_1 = generer_reponse(avis_test_1)
    print(f"💬 Réponse: {reponse_1}\n")
    
    # Test 2
    avis_test_2 = "Livraison en retard de 2 semaines, produit endommagé."
    print(f"📝 Avis test 2: {avis_test_2}")
    reponse_2 = generer_reponse(avis_test_2)
    print(f"💬 Réponse: {reponse_2}\n")
    
    print("✅ Tests terminés !")