File size: 12,137 Bytes
06c3276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
import asyncio
import os
import re
import subprocess
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
import tempfile
from typing import Dict, Optional, List, Any, Set

import hashlib
import requests
import aiohttp
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from graphgen.bases import BaseSearcher
from graphgen.utils import logger


@lru_cache(maxsize=None)
def _get_pool():
    """Return the shared thread pool for blocking lookups.

    The function takes no arguments and is wrapped in ``lru_cache``. The first
    executor built is therefore memoized, and every later call returns that same
    ``ThreadPoolExecutor`` (capped at 10 worker threads).
    """
    shared_executor = ThreadPoolExecutor(max_workers=10)
    return shared_executor

class RNACentralSearch(BaseSearcher):
    """
    RNAcentral Search client to search RNA databases.
    1) Get RNA by RNAcentral ID.
    2) Search with keywords or RNA names (fuzzy search).
    3) Search with RNA sequence (local BLAST when configured, otherwise an exact
       MD5 lookup against the RNAcentral API).

    API Documentation: https://rnacentral.org/api/v1
    """

    # RNAcentral URS identifiers are "URS" followed by hexadecimal characters,
    # so they may contain A-F. The all-digit form is still accepted for backward
    # compatibility with the previous ``URS\d+`` check.
    _RNA_ID_PATTERN = re.compile(r"URS(?:\d+|[0-9A-F]{10})", re.I)
    # Alphabet accepted for RNA sequence input (case-insensitive, whitespace allowed).
    _RNA_SEQ_PATTERN = re.compile(r"[AUCGN\s]+", re.I)

    def __init__(self, use_local_blast: bool = False, local_blast_db: str = "rna_db"):
        """
        :param use_local_blast: Try a local ``blastn`` search before the web API.
        :param local_blast_db: Path prefix of the local nucleotide BLAST database.
            If no database file is found for it, local BLAST is disabled.
        """
        super().__init__()
        self.base_url = "https://rnacentral.org/api/v1"
        self.headers = {"Accept": "application/json"}
        self.use_local_blast = use_local_blast
        self.local_blast_db = local_blast_db
        # A single-volume nucleotide DB has "<db>.nhr". A multi-volume DB is
        # reached through a "<db>.nal" alias file instead, so accept either.
        db_candidates = (f"{self.local_blast_db}{ext}" for ext in (".nhr", ".nal"))
        if self.use_local_blast and not any(os.path.isfile(p) for p in db_candidates):
            logger.error("Local BLAST database files not found. Please check the path.")
            self.use_local_blast = False

    @staticmethod
    def _extract_xrefs(rna_data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Return the cross-reference records embedded in an RNA entry.

        This client does not control the shape of ``xrefs``. It may be a list of
        records, an object with a ``results`` list, or a hyperlink string
        (NOTE(review): confirm against the RNAcentral ``flat=true`` response).
        Anything that is not a list of dict records yields ``[]``. This prevents
        ``_rna_data_to_dict`` from iterating over the characters of a URL.
        """
        xrefs = rna_data.get("xrefs")
        if isinstance(xrefs, dict):
            xrefs = xrefs.get("results")
        if not isinstance(xrefs, list):
            return []
        return [xref for xref in xrefs if isinstance(xref, dict)]

    @staticmethod
    def _rna_data_to_dict(
        rna_id: str,
        rna_data: Dict[str, Any],
        xrefs_data: Optional[List[Dict[str, Any]]] = None
    ) -> Dict[str, Any]:
        """
        Normalise an RNAcentral entry plus optional cross-references into the flat
        result dictionary returned by this searcher.

        Values aggregated from ``xrefs_data`` take precedence. When the xrefs do
        not supply a field, it is looked up on ``rna_data`` under the keys listed
        in ``fallback_rules``.

        :param rna_id: Identifier used for ``id`` and for the entry URL.
        :param rna_data: Entry payload (a detail record or a search hit).
        :param xrefs_data: Cross-reference records; non-dict items are ignored.
        :return: Result dictionary with a fixed set of keys.
        """
        organisms, gene_names, so_terms = set(), set(), set()
        modifications: List[Any] = []

        for xref in xrefs_data or []:
            if not isinstance(xref, dict):
                continue
            # JSON nulls arrive as None. A .get() default only covers *missing*
            # keys, so normalise explicitly.
            acc = xref.get("accession")
            if not isinstance(acc, dict):
                acc = {}
            if s := acc.get("species"):
                organisms.add(s)
            if g := str(acc.get("gene") or "").strip():
                gene_names.add(g)
            if m := xref.get("modifications"):
                modifications.extend(m)
            if b := acc.get("biotype"):
                so_terms.add(b)

        def format_unique_values(values: Set[str]) -> Optional[str]:
            """Collapse a set to None, its only element, or a sorted comma list."""
            if not values:
                return None
            if len(values) == 1:
                return next(iter(values))
            return ", ".join(sorted(values))

        xrefs_info = {
            "organism": format_unique_values(organisms),
            "gene_name": format_unique_values(gene_names),
            # Sorted so the output is deterministic (set order is not).
            "related_genes": sorted(gene_names) if gene_names else None,
            "modifications": modifications or None,
            "so_term": format_unique_values(so_terms),
        }

        # Keys on rna_data to consult, in order, when xrefs gave no value.
        fallback_rules = {
            "organism": ["organism", "species"],
            "related_genes": ["related_genes", "genes"],
            "gene_name": ["gene_name", "gene"],
            "so_term": ["so_term"],
            "modifications": ["modifications"],
        }

        def resolve_field(field_name: str) -> Any:
            """Return the xref-derived value, else the first non-None fallback."""
            if (value := xrefs_info.get(field_name)) is not None:
                return value

            for key in fallback_rules[field_name]:
                if (value := rna_data.get(key)) is not None:
                    return value

            return None

        organism = resolve_field("organism")
        gene_name = resolve_field("gene_name")
        so_term = resolve_field("so_term")
        modifications = resolve_field("modifications")

        related_genes = resolve_field("related_genes")
        if not related_genes and (single_gene := rna_data.get("gene_name")):
            related_genes = [single_gene]

        # Treat an explicit null sequence like a missing one so len() is safe.
        sequence = rna_data.get("sequence") or ""

        return {
            "molecule_type": "RNA",
            "database": "RNAcentral",
            "id": rna_id,
            "rnacentral_id": rna_data.get("rnacentral_id", rna_id),
            "sequence": sequence,
            "sequence_length": rna_data.get("length", len(sequence)),
            "rna_type": rna_data.get("rna_type", "N/A"),
            "description": rna_data.get("description", "N/A"),
            "url": f"https://rnacentral.org/rna/{rna_id}",
            "organism": organism,
            "related_genes": related_genes or None,
            "gene_name": gene_name,
            "so_term": so_term,
            "modifications": modifications,
        }

    @staticmethod
    def _calculate_md5(sequence: str) -> str:
        """
        Calculate MD5 hash for RNA sequence as per RNAcentral spec.
        - Drop all whitespace
        - Replace U with T
        - Convert to uppercase
        - Encode as ASCII

        :raises ValueError: if characters other than A/T/C/G/N remain.
        """
        compact = "".join(sequence.split())
        normalized_seq = compact.replace("U", "T").replace("u", "t").upper()
        if not re.fullmatch(r"[ATCGN]+", normalized_seq):
            raise ValueError(f"Invalid sequence characters after normalization: {normalized_seq[:50]}...")

        return hashlib.md5(normalized_seq.encode("ascii")).hexdigest()

    @classmethod
    def _extract_sequence(cls, sequence: str) -> Optional[str]:
        """
        Extract a compact RNA sequence from FASTA or raw input.

        For FASTA input only the first record is used, and every later header
        line ends it. All whitespace (including ``\\r`` from CRLF files) is
        removed. Returns None when nothing remains or when characters outside
        A/U/C/G/N appear.
        """
        text = sequence.strip()
        if text.startswith(">"):
            body: List[str] = []
            for line in text.splitlines()[1:]:
                if line.startswith(">"):
                    break
                body.append(line)
            text = "".join(body)
        seq = "".join(text.split())
        return seq if seq and cls._RNA_SEQ_PATTERN.fullmatch(seq) else None

    def get_by_rna_id(self, rna_id: str) -> Optional[dict]:
        """
        Get RNA information by RNAcentral ID.
        :param rna_id: RNAcentral ID (e.g., URS0000000001).
        :return: A dictionary containing RNA information or None if not found.
        """
        try:
            url = f"{self.base_url}/rna/{rna_id}"
            # flat=true is requested so that cross-references come back with the
            # entry itself (read from "xrefs" below).
            resp = requests.get(url, params={"flat": "true"}, headers=self.headers, timeout=30)
            resp.raise_for_status()

            rna_data = resp.json()
            xrefs_data = self._extract_xrefs(rna_data)
            return self._rna_data_to_dict(rna_id, rna_data, xrefs_data)
        except requests.RequestException as e:
            logger.error("Network error getting RNA ID %s: %s", rna_id, e)
            return None
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Unexpected error getting RNA ID %s: %s", rna_id, e)
            return None

    def get_best_hit(self, keyword: str) -> Optional[dict]:
        """
        Search RNAcentral with a keyword and return the best hit.

        The first search result is expanded via ``get_by_rna_id``. If that fails,
        the raw search-result fields are used instead.

        :param keyword: The search keyword (e.g., miRNA name, RNA name).
        :return: Dictionary with RNA information or None.
        """
        keyword = keyword.strip()
        if not keyword:
            logger.warning("Empty keyword provided to get_best_hit")
            return None

        try:
            url = f"{self.base_url}/rna"
            params = {"search": keyword, "format": "json"}
            resp = requests.get(url, params=params, headers=self.headers, timeout=30)
            resp.raise_for_status()

            data = resp.json()
            results = data.get("results", [])

            if not results:
                logger.info("No search results for keyword: %s", keyword)
                return None

            first_result = results[0]
            rna_id = first_result.get("rnacentral_id")

            if rna_id:
                detailed = self.get_by_rna_id(rna_id)
                if detailed:
                    return detailed
            logger.debug("Using search result data for %s", rna_id or "unknown")
            return self._rna_data_to_dict(rna_id or "", first_result)

        except requests.RequestException as e:
            logger.error("Network error searching keyword '%s': %s", keyword, e)
            return None
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Unexpected error searching keyword '%s': %s", keyword, e)
            return None

    def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
        """
        Perform local BLAST search using local BLAST database.

        The query is written to a temporary FASTA file, which is always deleted
        afterwards, including when ``blastn`` fails.

        :return: The subject accession of the first hit, or None on no hit/error.
        """
        tmp_name: Optional[str] = None
        try:
            with tempfile.NamedTemporaryFile(mode="w+", suffix=".fa", delete=False) as tmp:
                # Record the name first so cleanup works even if the write fails.
                tmp_name = tmp.name
                tmp.write(f">query\n{seq}\n")

            cmd = [
                "blastn", "-db", self.local_blast_db, "-query", tmp_name,
                "-evalue", str(threshold), "-max_target_seqs", "1", "-outfmt", "6 sacc"
            ]
            logger.debug("Running local blastn for RNA: %s", " ".join(cmd))
            out = subprocess.check_output(cmd, text=True).strip()
            return out.split("\n", maxsplit=1)[0].strip() if out else None
        except Exception as exc:  # pylint: disable=broad-except
            logger.error("Local blastn failed: %s", exc)
            return None
        finally:
            if tmp_name is not None:
                try:
                    os.remove(tmp_name)
                except OSError:
                    pass

    def get_by_fasta(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
        """
        Search RNAcentral with an RNA sequence.
        Tries local BLAST first if enabled, falls back to RNAcentral API.
        Unified approach: Find RNA ID from sequence search, then call get_by_rna_id() for complete information.

        The API fallback is an exact-match lookup by the sequence's MD5. It is
        also used when local BLAST finds an accession that cannot be resolved.

        :param sequence: RNA sequence (FASTA format or raw sequence). For FASTA,
            only the first record is searched.
        :param threshold: E-value threshold for BLAST search.
        :return: A dictionary containing complete RNA information or None if not found.
        """
        try:
            seq = self._extract_sequence(sequence)
            if not seq:
                logger.error("Empty or invalid RNA sequence provided.")
                return None

            # Try local BLAST first if enabled
            if self.use_local_blast:
                accession = self._local_blast(seq, threshold)
                if accession:
                    logger.debug("Local BLAST found accession: %s", accession)
                    detailed = self.get_by_rna_id(accession)
                    if detailed:
                        return detailed

            # Fall back to RNAcentral API if local BLAST didn't find result
            logger.debug("Falling back to RNAcentral API.")

            md5_hash = self._calculate_md5(seq)
            search_url = f"{self.base_url}/rna"
            params = {"md5": md5_hash, "format": "json"}

            resp = requests.get(search_url, params=params, headers=self.headers, timeout=60)
            resp.raise_for_status()

            search_results = resp.json()
            results = search_results.get("results", [])

            if not results:
                logger.info("No exact match found in RNAcentral for sequence")
                return None
            rna_id = results[0].get("rnacentral_id")
            if not rna_id:
                logger.error("No RNAcentral ID found in search results.")
                return None
            return self.get_by_rna_id(rna_id)
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Sequence search failed: %s", e)
            return None

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
        reraise=True,
    )
    async def search(self, query: str, threshold: float = 0.1, **kwargs) -> Optional[Dict]:
        """
        Search RNAcentral with either an RNAcentral ID, keyword, or RNA sequence.

        Queries are dispatched in this order:
        1. A FASTA query, or a raw sequence containing U, goes to ``get_by_fasta``.
        2. A URS identifier goes to ``get_by_rna_id``.
        3. Anything else goes to ``get_best_hit``.

        The blocking lookup runs on the shared thread pool.

        NOTE(review): the helper methods use ``requests`` and catch every
        ``Exception`` themselves. The aiohttp/asyncio.TimeoutError retry policy
        above therefore never triggers. Confirm whether retries are still wanted.

        :param query: Sequence, RNAcentral ID or keyword.
        :param threshold: E-value threshold forwarded to ``get_by_fasta``.
        :param kwargs: Accepted for interface compatibility; unused.
        :return: Result dictionary tagged with ``_search_query``, or None.
        """
        if not query or not isinstance(query, str):
            logger.error("Empty or non-string input.")
            return None

        query = query.strip()
        logger.debug("RNAcentral search query: %s", query)

        loop = asyncio.get_running_loop()

        # check if RNA sequence (AUCG characters, contains U)
        if query.startswith(">") or (
            self._RNA_SEQ_PATTERN.fullmatch(query) and "U" in query.upper()
        ):
            result = await loop.run_in_executor(_get_pool(), self.get_by_fasta, query, threshold)
        # check if RNAcentral ID ("URS" + hexadecimal characters)
        elif self._RNA_ID_PATTERN.fullmatch(query):
            result = await loop.run_in_executor(_get_pool(), self.get_by_rna_id, query)
        else:
            # otherwise treat as keyword
            result = await loop.run_in_executor(_get_pool(), self.get_best_hit, query)

        if result:
            result["_search_query"] = query
        return result