import copy
import logging

from pydantic import BaseModel, Field

from hugginggpt.exceptions import TaskParsingException, wrap_exceptions

logger = logging.getLogger(__name__)

GENERATED_TOKEN = "<GENERATED>"

class Task(BaseModel):
    # This field is called 'task' and not 'name' to help with prompt engineering
    task: str = Field(description="Name of the Machine Learning task")
    id: int = Field(description="ID of the task")
    dep: list[int] = Field(
        description="List of IDs of the tasks that this task depends on"
    )
    args: dict[str, str] = Field(description="Arguments for the task")

    def depends_on_generated_resources(self) -> bool:
        """Returns True if the task args contain <GENERATED> placeholder tokens, False otherwise"""
        return self.dep != [-1] and any(
            GENERATED_TOKEN in v for v in self.args.values()
        )

    def replace_generated_resources(self, task_summaries: list):
        """Replaces <GENERATED> placeholder tokens in args with the generated resources from the task summaries"""
        logger.info("Replacing generated resources")
        generated_resources = {
            k: parse_task_id(v) for k, v in self.args.items() if GENERATED_TOKEN in v
        }
        logger.info(
            f"Resources to replace, resource type -> task id: {generated_resources}"
        )
        for resource_type, task_id in generated_resources.items():
            matches = [
                v
                for k, v in task_summaries[task_id].inference_result.items()
                if self.is_matching_generated_resource(k, resource_type)
            ]
            if len(matches) == 1:
                logger.info(
                    f"Match for generated {resource_type} in inference result of task {task_id}"
                )
                generated_resource = matches[0]
                logger.info(f"Replacing {resource_type} with {generated_resource}")
                self.args[resource_type] = generated_resource
            else:
                raise Exception(
                    f"Cannot find unique required generated {resource_type} in inference result of task {task_id}"
                )
        # Return only after every generated placeholder has been replaced
        return self

    def is_matching_generated_resource(self, arg_key: str, resource_type: str) -> bool:
        """Returns True if arg_key refers to a generated resource of the correct type"""
        # If text, then match all arg keys that contain "text"
        if resource_type.startswith("text"):
            return "text" in arg_key
        # If not text, then the arg key must start with "generated" followed by the resource type
        else:
            return arg_key.startswith("generated " + resource_type)

class Tasks(BaseModel):
    __root__: list[Task] = Field(description="List of Machine Learning tasks")

    def __iter__(self):
        return iter(self.__root__)

    def __getitem__(self, item):
        return self.__root__[item]

    def __len__(self):
        return len(self.__root__)

def parse_tasks(tasks_str: str) -> list[Task]:
    """Parses tasks from a task planning JSON string"""
    if tasks_str == "[]":
        raise ValueError("Task string empty, cannot parse")
    logger.info(f"Parsing tasks string: {tasks_str}")
    tasks_str = tasks_str.strip()
    # Cannot use PydanticOutputParser because it fails when parsing a top-level list JSON string
    tasks = Tasks.parse_raw(tasks_str)
    # __root__ extracts list[Task] from the Tasks object
    tasks = unfold(tasks.__root__)
    tasks = fix_dependencies(tasks)
    logger.info(f"Parsed tasks: {tasks}")
    return tasks

def parse_task_id(resource_str: str) -> int:
    """Parse task id from generated resource string, e.g. <GENERATED>-4 -> 4"""
    return int(resource_str.split("-")[1])

def fix_dependencies(tasks: list[Task]) -> list[Task]:
    """Ignores parsed task dependencies, and instead infers them from the task arguments"""
    for task in tasks:
        task.dep = infer_deps_from_args(task)
    return tasks

def infer_deps_from_args(task: Task) -> list[int]:
    """If GENERATED arg value, add to list of unique deps. If none, deps = [-1]"""
    deps = [parse_task_id(v) for v in task.args.values() if GENERATED_TOKEN in v]
    if not deps:
        deps = [-1]
    # deduplicate
    return list(set(deps))
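
# Illustrative note, not part of the original source: for args such as
#   {"image": "<GENERATED>-0", "text": "describe this image"}
# infer_deps_from_args returns [0]; for args without <GENERATED> tokens it returns [-1].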

def unfold(tasks: list[Task]) -> list[Task]:
    """A folded task has several generated resources folded into a single argument; such tasks are split into one task per resource"""
    unfolded_tasks = []
    for task in tasks:
        folded_args = find_folded_args(task)
        if folded_args:
            unfolded_tasks.extend(split(task, folded_args))
        else:
            unfolded_tasks.append(task)
    return unfolded_tasks

def split(task: Task, folded_args: tuple[str, str]) -> list[Task]:
    """Splits a folded task into identical copies, each with a single generated resource argument and its own dependency"""
    key, value = folded_args
    generated_items = value.split(",")
    split_tasks = []
    for item in generated_items:
        new_task = copy.deepcopy(task)
        dep_task_id = parse_task_id(item)
        new_task.dep = [dep_task_id]
        new_task.args[key] = item.strip()
        split_tasks.append(new_task)
    return split_tasks
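
# Illustrative example, not part of the original source: a folded task such as
#   {"task": "visual-question-answering", "id": 2, "dep": [0, 1],
#    "args": {"image": "<GENERATED>-0,<GENERATED>-1", "text": "what is in these images?"}}
# is split into two copies, one with args {"image": "<GENERATED>-0", ...} and dep [0],
# the other with args {"image": "<GENERATED>-1", ...} and dep [1].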

def find_folded_args(task: Task) -> tuple[str, str] | None:
    """Finds folded args, e.g: 'image': '<GENERATED>-1,<GENERATED>-2'"""
    for key, value in task.args.items():
        if value.count(GENERATED_TOKEN) > 1:
            logger.debug(f"Task {task.id} is folded")
            return key, value
    return None
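
if __name__ == "__main__":
    # Usage sketch, not part of the original module: assumes the hugginggpt package is
    # importable and demonstrates parsing a task planning string that contains a
    # folded <GENERATED> argument.
    logging.basicConfig(level=logging.INFO)
    example_plan = (
        '[{"task": "image-to-text", "id": 0, "dep": [-1], "args": {"image": "a.jpg"}},'
        ' {"task": "image-to-text", "id": 1, "dep": [-1], "args": {"image": "b.jpg"}},'
        ' {"task": "visual-question-answering", "id": 2, "dep": [0, 1],'
        ' "args": {"image": "<GENERATED>-0,<GENERATED>-1", "text": "what do these images have in common?"}}]'
    )
    for parsed_task in parse_tasks(example_plan):
        print(parsed_task)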