Spaces:
Paused
Paused
| import sys | |
| import json | |
| import hashlib | |
| import torch | |
| from typing import List | |
def get_md5(input_str):
    """Return the hexadecimal MD5 digest of *input_str* (UTF-8 encoded)."""
    return hashlib.md5(input_str.encode("utf-8")).hexdigest()
def tool_result_format(function_call_messages):
    """Render tool-role message contents as a collapsible Markdown block.

    Fixes a typo in the user-facing summary line ("Verfied" -> "Verified").

    Args:
        function_call_messages: List of chat message dicts. Only entries whose
            'role' is 'tool' contribute their 'content' to the output; all
            other roles are ignored.

    Returns:
        str: A Markdown string wrapping every tool message's content in a
        single ``<details>`` section.
    """
    current_output = "\n\n<details>\n<summary> <strong>Verified Feedback from Tools</strong>, click to see details:</summary>\n\n"
    for each_message in function_call_messages:
        if each_message['role'] == 'tool':
            current_output += f"{each_message['content']}\n\n"
    current_output += "</details>\n\n\n"
    return current_output
class NoRepeatSentenceProcessor:
    """Logits processor that blocks tokens extending a forbidden sentence.

    For each forbidden token sequence longer than ``allowed_prefix_length``
    (call it k), the first k tokens act as a lookup key and the (k+1)-th token
    is recorded as blocked. At generation time, when the first k generated
    tokens equal such a key, every recorded continuation token gets its logit
    set to -inf.
    """

    def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
        """
        Args:
            forbidden_sequences: Token-ID sequences of forbidden sentences.
            allowed_prefix_length: The number k such that matching the first k
                tokens of a forbidden sequence blocks the token that would
                extend the match.
        """
        self.allowed_prefix_length = allowed_prefix_length
        # Map: tuple of the first k tokens -> set of continuation tokens to block.
        self.forbidden_prefix_dict = {}
        for sequence in forbidden_sequences:
            # Sequences no longer than k have no continuation token to block.
            if len(sequence) <= allowed_prefix_length:
                continue
            key = tuple(sequence[:allowed_prefix_length])
            blocked = self.forbidden_prefix_dict.setdefault(key, set())
            blocked.add(sequence[allowed_prefix_length])

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        """Mask (in place) tokens that would extend a forbidden sentence.

        Args:
            token_ids: Token IDs generated so far.
            logits: Next-token logits (shape: [vocab_size]); modified in place.

        Returns:
            torch.Tensor: The (possibly modified) logits tensor.
        """
        if len(token_ids) >= self.allowed_prefix_length:
            generated_prefix = tuple(token_ids[:self.allowed_prefix_length])
            blocked = self.forbidden_prefix_dict.get(generated_prefix)
            if blocked is not None:
                for blocked_id in blocked:
                    logits[blocked_id] = -float("inf")
        return logits
class ReasoningTraceChecker:
    """Scans an agent conversation for repeated thoughts or repeated tool calls.

    Fixes vs. the previous version:
    - removed a stray debug ``print(each)`` in ``check_conversation``;
    - ``check_repeat_action`` no longer mutates the caller's message dicts
      (it previously did ``del each_action['call_id']`` in place);
    - ``isinstance`` instead of ``type(...) != list``;
    - ``json.dumps(..., sort_keys=True)`` so action dedup does not depend on
      dict key insertion order.
    """

    def __init__(self, question, conversation, init_index=None):
        """
        Args:
            question: The user question; stored lower-cased.
            conversation: List of chat message dicts ('role', 'content', and
                for assistant messages 'tool_calls').
            init_index: Conversation index to start checking from. Defaults
                to 1, i.e. the first message is skipped.
        """
        self.question = question.lower()
        self.conversation = conversation
        self.existing_thoughts = []
        self.existing_actions = []
        self.index = 1 if init_index is None else init_index
        self.new_thoughts = []
        self.new_actions = []

    def check_conversation(self):
        """Check assistant messages from ``self.index`` onward.

        Returns:
            tuple[bool, str]: (True, info) if no repeats were found, otherwise
            (False, info) where info names the failure ('repeat_thought' or
            'repeat_action').
        """
        info = ''
        for i in range(self.index, len(self.conversation)):
            each = self.conversation[i]
            # Remember progress so a later call resumes from this message.
            self.index = i
            if each['role'] != 'assistant':
                continue
            good_status, current_info = self.check_repeat_thought(each['content'])
            info += current_info
            if not good_status:
                return False, info
            good_status, current_info = self.check_repeat_action(each['tool_calls'])
            info += current_info
            if not good_status:
                return False, info
        return True, info

    def check_repeat_thought(self, thought):
        """Record *thought*; return (False, 'repeat_thought') if seen before."""
        if thought in self.existing_thoughts:
            return False, "repeat_thought"
        self.existing_thoughts.append(thought)
        return True, ''

    def check_repeat_action(self, actions):
        """Record each action; return (False, 'repeat_action') on a duplicate.

        Actions are compared with their 'call_id' field ignored, since call
        IDs differ even for otherwise-identical tool calls.
        """
        if not isinstance(actions, list):
            actions = json.loads(actions)
        for each_action in actions:
            # Normalize on a copy so the caller's message dicts are untouched;
            # sort_keys makes the serialized form independent of key order.
            normalized = {k: v for k, v in each_action.items() if k != 'call_id'}
            serialized = json.dumps(normalized, sort_keys=True)
            if serialized in self.existing_actions:
                return False, "repeat_action"
            self.existing_actions.append(serialized)
        return True, ''