Spaces:
Running
Running
File size: 12,137 Bytes
06c3276 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 |
import asyncio
import os
import re
import subprocess
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
import tempfile
from typing import Dict, Optional, List, Any, Set
import hashlib
import requests
import aiohttp
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from graphgen.bases import BaseSearcher
from graphgen.utils import logger
@lru_cache(maxsize=None)
def _get_pool():
return ThreadPoolExecutor(max_workers=10)
class RNACentralSearch(BaseSearcher):
    """
    RNAcentral Search client to search RNA databases.

    1) Get RNA by RNAcentral ID.
    2) Search with keywords or RNA names (fuzzy search).
    3) Search with RNA sequence: optional local BLAST first, then an
       exact-match lookup by sequence MD5 through the RNAcentral API.

    API Documentation: https://rnacentral.org/api/v1
    """

    # RNAcentral identifiers are "URS" followed by 10 hexadecimal characters
    # (e.g. URS000075A546), so the suffix may contain A-F as well as digits.
    _RNA_ID_PATTERN = re.compile(r"URS[0-9A-F]{10}", re.I)

    def __init__(self, use_local_blast: bool = False, local_blast_db: str = "rna_db"):
        """
        :param use_local_blast: Try a local ``blastn`` search before the
            RNAcentral API when searching by sequence.
        :param local_blast_db: Path prefix of the local BLAST nucleotide
            database (passed to ``blastn -db``).
        """
        super().__init__()
        self.base_url = "https://rnacentral.org/api/v1"
        self.headers = {"Accept": "application/json"}
        self.use_local_blast = use_local_blast
        self.local_blast_db = local_blast_db
        # A single-volume nucleotide database has a ``.nhr`` header file; a
        # multi-volume one is addressed through a ``.nal`` alias file.
        if self.use_local_blast and not any(
            os.path.isfile(f"{self.local_blast_db}{ext}") for ext in (".nhr", ".nal")
        ):
            logger.error("Local BLAST database files not found. Please check the path.")
            self.use_local_blast = False

    @staticmethod
    def _normalize_xrefs(xrefs: Any) -> List[Dict[str, Any]]:
        """
        Coerce the ``xrefs`` field of an RNA record into a list of dicts.

        NOTE(review): the exact payload shape is not visible here. A paginated
        ``{"results": [...]}`` object is unwrapped. Any other non-list value
        yields ``[]``, so a surprising shape degrades to "no cross-references"
        instead of discarding the whole record.
        """
        if isinstance(xrefs, dict):
            xrefs = xrefs.get("results")
        if not isinstance(xrefs, list):
            return []
        return [xref for xref in xrefs if isinstance(xref, dict)]

    @staticmethod
    def _rna_data_to_dict(
        rna_id: str,
        rna_data: Dict[str, Any],
        xrefs_data: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """
        Normalise an RNAcentral record into this searcher's result dictionary.

        Values aggregated from ``xrefs_data`` take precedence. Fields that are
        missing there fall back to the keys listed in ``fallback_rules`` on
        ``rna_data``.

        :param rna_id: Identifier used for ``id`` and the record URL.
        :param rna_data: Record payload returned by the RNAcentral API.
        :param xrefs_data: Cross-reference entries. Each may carry an
            ``accession`` dict (``species`` / ``gene`` / ``biotype``) and a
            ``modifications`` list.
        :return: Flat dictionary describing the RNA.
        """
        organisms: Set[str] = set()
        gene_names: Set[str] = set()
        so_terms: Set[str] = set()
        modifications: List[Any] = []
        for xref in xrefs_data or []:
            # ``or {}`` / ``or ""`` also cover keys that are present but null.
            acc = xref.get("accession") or {}
            if species := acc.get("species"):
                organisms.add(species)
            if gene := (acc.get("gene") or "").strip():
                gene_names.add(gene)
            if mods := xref.get("modifications"):
                modifications.extend(mods)
            if biotype := acc.get("biotype"):
                so_terms.add(biotype)

        def format_unique_values(values: Set[str]) -> Optional[str]:
            # None when empty, the bare value when unique, else a sorted join.
            if not values:
                return None
            if len(values) == 1:
                return next(iter(values))
            return ", ".join(sorted(values))

        xrefs_info = {
            "organism": format_unique_values(organisms),
            "gene_name": format_unique_values(gene_names),
            # Sorted so the output is deterministic (set order varies by run).
            "related_genes": sorted(gene_names) if gene_names else None,
            "modifications": modifications or None,
            "so_term": format_unique_values(so_terms),
        }
        fallback_rules = {
            "organism": ["organism", "species"],
            "related_genes": ["related_genes", "genes"],
            "gene_name": ["gene_name", "gene"],
            "so_term": ["so_term"],
            "modifications": ["modifications"],
        }

        def resolve_field(field_name: str) -> Any:
            # Prefer the xref-derived value, then the first non-None fallback key.
            if (value := xrefs_info.get(field_name)) is not None:
                return value
            for key in fallback_rules[field_name]:
                if (value := rna_data.get(key)) is not None:
                    return value
            return None

        organism = resolve_field("organism")
        gene_name = resolve_field("gene_name")
        so_term = resolve_field("so_term")
        modifications = resolve_field("modifications")
        related_genes = resolve_field("related_genes")
        if not related_genes and (single_gene := rna_data.get("gene_name")):
            related_genes = [single_gene]

        # ``or ""`` keeps ``len(sequence)`` safe when the API sends null.
        sequence = rna_data.get("sequence") or ""
        return {
            "molecule_type": "RNA",
            "database": "RNAcentral",
            "id": rna_id,
            "rnacentral_id": rna_data.get("rnacentral_id", rna_id),
            "sequence": sequence,
            "sequence_length": rna_data.get("length", len(sequence)),
            "rna_type": rna_data.get("rna_type", "N/A"),
            "description": rna_data.get("description", "N/A"),
            "url": f"https://rnacentral.org/rna/{rna_id}",
            "organism": organism,
            "related_genes": related_genes or None,
            "gene_name": gene_name,
            "so_term": so_term,
            "modifications": modifications,
        }

    @staticmethod
    def _calculate_md5(sequence: str) -> str:
        """
        Calculate MD5 hash for RNA sequence as per RNAcentral spec.

        - Replace U with T
        - Convert to uppercase
        - Encode as ASCII

        :raises ValueError: if characters other than A/C/G/T/N remain.
        """
        normalized_seq = sequence.replace("U", "T").replace("u", "t").upper()
        if not re.fullmatch(r"[ATCGN]+", normalized_seq):
            raise ValueError(
                f"Invalid sequence characters after normalization: {normalized_seq[:50]}..."
            )
        return hashlib.md5(normalized_seq.encode("ascii")).hexdigest()

    def get_by_rna_id(self, rna_id: str) -> Optional[dict]:
        """
        Get RNA information by RNAcentral ID.

        :param rna_id: RNAcentral ID (e.g., URS0000000001).
        :return: A dictionary containing RNA information or None if not found
            or on any error (errors are logged, not raised).
        """
        try:
            url = f"{self.base_url}/rna/{rna_id}"
            url += "?flat=true"
            resp = requests.get(url, headers=self.headers, timeout=30)
            resp.raise_for_status()
            rna_data = resp.json()
            xrefs_data = self._normalize_xrefs(rna_data.get("xrefs"))
            return self._rna_data_to_dict(rna_id, rna_data, xrefs_data)
        except requests.RequestException as e:
            logger.error("Network error getting RNA ID %s: %s", rna_id, e)
            return None
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Unexpected error getting RNA ID %s: %s", rna_id, e)
            return None

    def get_best_hit(self, keyword: str) -> Optional[dict]:
        """
        Search RNAcentral with a keyword and return the best hit.

        The first search result is expanded with ``get_by_rna_id``. If that
        fails, the search-result payload itself is converted.

        :param keyword: The search keyword (e.g., miRNA name, RNA name).
        :return: Dictionary with RNA information or None.
        """
        keyword = (keyword or "").strip()
        if not keyword:
            logger.warning("Empty keyword provided to get_best_hit")
            return None
        try:
            url = f"{self.base_url}/rna"
            params = {"search": keyword, "format": "json"}
            resp = requests.get(url, params=params, headers=self.headers, timeout=30)
            resp.raise_for_status()
            results = resp.json().get("results", [])
            if not results:
                logger.info("No search results for keyword: %s", keyword)
                return None
            first_result = results[0]
            rna_id = first_result.get("rnacentral_id")
            if rna_id:
                detailed = self.get_by_rna_id(rna_id)
                if detailed:
                    return detailed
            logger.debug("Using search result data for %s", rna_id or "unknown")
            return self._rna_data_to_dict(rna_id or "", first_result)
        except requests.RequestException as e:
            logger.error("Network error searching keyword '%s': %s", keyword, e)
            return None
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Unexpected error searching keyword '%s': %s", keyword, e)
            return None

    def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
        """
        Perform local BLAST search using local BLAST database.

        :param seq: Query sequence (written to a temporary FASTA file).
        :param threshold: E-value cut-off passed to ``blastn -evalue``.
        :return: Subject accession of the first output line, or None when
            there is no hit or ``blastn`` fails.
        """
        tmp_name: Optional[str] = None
        try:
            with tempfile.NamedTemporaryFile(mode="w+", suffix=".fa", delete=False) as tmp:
                tmp_name = tmp.name
                tmp.write(f">query\n{seq}\n")
            cmd = [
                "blastn", "-db", self.local_blast_db, "-query", tmp_name,
                "-evalue", str(threshold), "-max_target_seqs", "1", "-outfmt", "6 sacc",
            ]
            logger.debug("Running local blastn for RNA: %s", " ".join(cmd))
            out = subprocess.check_output(cmd, text=True).strip()
            return out.split("\n", maxsplit=1)[0] if out else None
        except Exception as exc:  # pylint: disable=broad-except
            logger.error("Local blastn failed: %s", exc)
            return None
        finally:
            # delete=False leaves the file on disk; remove it on every path,
            # including when blastn is missing or exits non-zero.
            if tmp_name is not None:
                try:
                    os.remove(tmp_name)
                except OSError as exc:
                    logger.debug("Could not remove temporary file %s: %s", tmp_name, exc)

    def get_by_fasta(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
        """
        Search RNAcentral with an RNA sequence.

        Tries local BLAST first if enabled, and falls back to an exact MD5
        lookup through the RNAcentral API when BLAST finds nothing or its hit
        cannot be resolved. The RNA ID found is then expanded with
        ``get_by_rna_id()`` for complete information.

        :param sequence: RNA sequence (FASTA format or raw sequence).
        :param threshold: E-value threshold for BLAST search.
        :return: A dictionary containing complete RNA information or None if
            not found.
        """

        def _extract_sequence(raw: str) -> Optional[str]:
            """Strip an optional FASTA header and all whitespace; validate the letters."""
            text = raw.strip()
            if text.startswith(">"):
                # Drop the header line; the sequence may span several lines.
                text = text.split("\n", 1)[1] if "\n" in text else ""
            # Remove every whitespace character (spaces, tabs, \r, \n) so the
            # result is acceptable to _calculate_md5.
            seq = re.sub(r"\s+", "", text)
            return seq if seq and re.fullmatch(r"[AUCGN]+", seq, re.I) else None

        try:
            seq = _extract_sequence(sequence)
            if not seq:
                logger.error("Empty or invalid RNA sequence provided.")
                return None

            if self.use_local_blast:
                accession = self._local_blast(seq, threshold)
                if accession:
                    logger.debug("Local BLAST found accession: %s", accession)
                    detailed = self.get_by_rna_id(accession)
                    if detailed:
                        return detailed

            logger.debug("Falling back to RNAcentral API.")
            md5_hash = self._calculate_md5(seq)
            search_url = f"{self.base_url}/rna"
            params = {"md5": md5_hash, "format": "json"}
            resp = requests.get(search_url, params=params, headers=self.headers, timeout=60)
            resp.raise_for_status()
            results = resp.json().get("results", [])
            if not results:
                logger.info("No exact match found in RNAcentral for sequence")
                return None
            rna_id = results[0].get("rnacentral_id")
            if not rna_id:
                logger.error("No RNAcentral ID found in search results.")
                return None
            return self.get_by_rna_id(rna_id)
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Sequence search failed: %s", e)
            return None

    # NOTE(review): the synchronous helpers called below catch every exception
    # and none of them use aiohttp, so this retry policy cannot be triggered
    # by the code visible in this class.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
        reraise=True,
    )
    async def search(self, query: str, threshold: float = 0.1, **kwargs) -> Optional[Dict]:
        """
        Search RNAcentral with either an RNAcentral ID, keyword, or RNA sequence.

        Routing:
        - FASTA input, or only A/U/C/G/N/whitespace containing at least one U,
          goes to ``get_by_fasta``.
        - An RNAcentral ID ("URS" + 10 hex characters) goes to ``get_by_rna_id``.
        - Anything else goes to ``get_best_hit`` as a keyword.

        The blocking lookup runs on the shared thread pool. ``kwargs`` is
        accepted for interface compatibility and is unused.

        :return: Result dictionary tagged with ``_search_query``, or None.
        """
        if not query or not isinstance(query, str):
            logger.error("Empty or non-string input.")
            return None
        query = query.strip()
        logger.debug("RNAcentral search query: %s", query)
        loop = asyncio.get_running_loop()
        pool = _get_pool()

        if query.startswith(">") or (
            re.fullmatch(r"[AUCGN\s]+", query, re.I) and "U" in query.upper()
        ):
            result = await loop.run_in_executor(pool, self.get_by_fasta, query, threshold)
        elif self._RNA_ID_PATTERN.fullmatch(query):
            # Canonical RNAcentral IDs are upper case.
            result = await loop.run_in_executor(pool, self.get_by_rna_id, query.upper())
        else:
            result = await loop.run_in_executor(pool, self.get_best_hit, query)

        if result:
            result["_search_query"] = query
        return result
|