Source code for sirnaforge.validation.utils

"""Validation utilities for data consistency and cross-validation."""

from typing import Any

import pandas as pd

from sirnaforge.models.schemas import OffTargetHitsSchema, ORFValidationSchema, SiRNACandidateSchema
from sirnaforge.models.sirna import DesignParameters, SiRNACandidate
from sirnaforge.utils.logging_utils import get_logger

logger = get_logger(__name__)

# Constants derived from model definitions to avoid duplication.
# NOTE(review): these are hand-written copies of the SiRNACandidate length
# limits; they must be updated manually if the model's field limits change.
SIRNA_MIN_LENGTH = 19  # From SiRNACandidate.length field constraints
SIRNA_MAX_LENGTH = 23  # From SiRNACandidate.length field constraints


class ValidationResult:
    """Accumulate errors, warnings and metadata produced by a validation pass.

    A result starts with the given validity flag (``True`` by default) and is
    flipped to invalid whenever an error is recorded or an invalid result is
    merged into it. Warnings and metadata never affect validity.
    """

    def __init__(self, is_valid: bool = True):
        """Create an empty result.

        Args:
            is_valid: Initial validity flag.
        """
        self.is_valid = is_valid
        self.errors: list[str] = []
        self.warnings: list[str] = []
        self.metadata: dict[str, Any] = {}

    def add_error(self, message: str) -> None:
        """Record ``message`` as an error; this marks the result invalid."""
        self.errors.append(message)
        self.is_valid = False

    def add_warning(self, message: str) -> None:
        """Record ``message`` as a non-fatal warning."""
        self.warnings.append(message)

    def add_metadata(self, key: str, value: Any) -> None:
        """Attach ``value`` under ``key``, replacing any previous value."""
        self.metadata.update({key: value})

    def merge(self, other: "ValidationResult") -> None:
        """Fold ``other`` into this result in place.

        This result becomes invalid if ``other`` is invalid; otherwise its
        validity is left untouched. ``other``'s errors and warnings are
        appended after this result's own, and ``other``'s metadata overrides
        any keys already present here.
        """
        self.is_valid = self.is_valid if other.is_valid else False
        self.errors += other.errors
        self.warnings += other.warnings
        self.metadata |= other.metadata

    def summary(self) -> dict[str, Any]:
        """Return a dict with the validity flag, counts, and collected entries.

        The ``errors``, ``warnings`` and ``metadata`` values are the live
        containers held by this result, not copies.
        """
        errors, warnings = self.errors, self.warnings
        return dict(
            is_valid=self.is_valid,
            error_count=len(errors),
            warning_count=len(warnings),
            errors=errors,
            warnings=warnings,
            metadata=self.metadata,
        )
class ValidationUtils:
    """Utility functions for data validation.

    Every public method is a ``@staticmethod`` that returns a
    :class:`ValidationResult` describing errors, warnings and metadata.
    """

    @staticmethod
    def validate_nucleotide_sequence(sequence: str, allow_ambiguous: bool = True) -> ValidationResult:
        """Validate nucleotide sequence composition.

        Args:
            sequence: Nucleotide sequence; compared case-insensitively.
            allow_ambiguous: When True, IUPAC ambiguity codes (``NRYSWKMBDHV``)
                and gap characters (``-``) are accepted in addition to ``ATCGU``.

        Returns:
            A result with an error listing any invalid characters, one warning
            per base (A/T/C/G/U) that appears as a run of 4+ identical bases,
            and ``length`` / ``gc_content`` metadata.
        """
        result = ValidationResult()
        upper_seq = sequence.upper()  # normalise once, reused by every check below

        valid_chars = set("ATCGU")
        if allow_ambiguous:
            valid_chars.update("NRYSWKMBDHV-")

        invalid_chars = set(upper_seq) - valid_chars
        if invalid_chars:
            result.add_error(f"Invalid nucleotides found: {sorted(invalid_chars)}")

        # Flag excessive homopolymer runs (4+ consecutive identical bases).
        for base in "ATCGU":
            if base * 4 in upper_seq:
                result.add_warning(f"Poly-{base} run detected in sequence")

        result.add_metadata("length", len(sequence))
        result.add_metadata("gc_content", ValidationUtils._calculate_gc_content(sequence))
        return result

    @staticmethod
    def validate_sirna_length(sequence: str) -> ValidationResult:
        """Validate siRNA sequence length.

        Lengths outside ``SIRNA_MIN_LENGTH``..``SIRNA_MAX_LENGTH`` (inclusive)
        are errors; in-range lengths other than 21 produce a warning.
        The measured length is stored as ``length`` metadata.
        """
        result = ValidationResult()
        length = len(sequence)
        # Use the shared module constants rather than re-hard-coding 19/23 so
        # the bounds cannot drift from the rest of the module.
        if not (SIRNA_MIN_LENGTH <= length <= SIRNA_MAX_LENGTH):
            result.add_error(
                f"siRNA length {length} outside valid range ({SIRNA_MIN_LENGTH}-{SIRNA_MAX_LENGTH})"
            )
        elif length != 21:
            result.add_warning(f"siRNA length {length} is non-standard (21 is optimal)")
        result.add_metadata("length", length)
        return result

    @staticmethod
    def validate_parameter_consistency(params: DesignParameters) -> ValidationResult:
        """Validate design parameter consistency.

        Checks:
            * ``filters.gc_min`` must not exceed ``filters.gc_max`` (error);
              a well-formed range narrower than 5 produces a warning.
            * The five scoring weights must sum to 1.0 within 0.01 (error).
            * ``top_n`` above 1000 produces a performance warning.

        The summed weight is stored as ``total_weight`` metadata.
        """
        result = ValidationResult()

        filters = params.filters
        if filters.gc_min > filters.gc_max:
            result.add_error("gc_min cannot be greater than gc_max")
        elif filters.gc_max - filters.gc_min < 5:
            # Only meaningful for a well-formed range: an inverted range has a
            # negative span and is already reported as an error above.
            result.add_warning("Very narrow GC content range may yield few candidates")

        scoring = params.scoring
        total_weight = (
            scoring.asymmetry
            + scoring.gc_content
            + scoring.accessibility
            + scoring.off_target
            + scoring.empirical
        )
        if abs(total_weight - 1.0) > 0.01:
            result.add_error(f"Scoring weights sum to {total_weight:.3f}, should be 1.0")

        if params.top_n > 1000:
            result.add_warning("Large top_n value may impact performance")

        result.add_metadata("total_weight", total_weight)
        return result

    @staticmethod
    def validate_candidate_consistency(candidate: SiRNACandidate) -> ValidationResult:
        """Validate siRNA candidate internal consistency.

        Errors: non-positive ``position``, ``composite_score`` outside 0-100,
        ``asymmetry_score`` outside 0-1.
        Warnings: guide/passenger length mismatch, and a reported
        ``gc_content`` more than 5 percentage points away from the value
        recomputed from the guide sequence.

        The recomputed GC content is stored as ``calculated_gc`` metadata.
        """
        result = ValidationResult()

        # Different guide/passenger lengths are allowed but flagged so that
        # downstream processing can handle them.
        if len(candidate.guide_sequence) != len(candidate.passenger_sequence):
            result.add_warning("Guide and passenger sequences have different lengths")

        if candidate.position <= 0:
            result.add_error("Position must be positive")

        calculated_gc = ValidationUtils._calculate_gc_content(candidate.guide_sequence)
        if abs(calculated_gc - candidate.gc_content) > 5.0:
            result.add_warning(
                f"Reported GC content ({candidate.gc_content:.1f}%) differs from calculated ({calculated_gc:.1f}%)"
            )

        if not (0 <= candidate.composite_score <= 100):
            result.add_error(f"Composite score {candidate.composite_score} outside valid range (0-100)")
        if not (0 <= candidate.asymmetry_score <= 1):
            result.add_error(f"Asymmetry score {candidate.asymmetry_score} outside valid range (0-1)")

        result.add_metadata("calculated_gc", calculated_gc)
        return result

    @staticmethod
    def validate_dataframe_schema(df: pd.DataFrame, schema_type: str) -> ValidationResult:
        """Validate a DataFrame against the pandera schema named by ``schema_type``.

        Args:
            df: DataFrame to validate.
            schema_type: One of ``"sirna_candidates"``, ``"orf_validation"``
                or ``"off_target_hits"``.

        Returns:
            A result with ``validated_rows`` metadata on success, an
            "Unknown schema type" error for an unrecognised ``schema_type``,
            or a "Schema validation failed" error if validation raises.
        """
        result = ValidationResult()

        schemas = {
            "sirna_candidates": SiRNACandidateSchema,
            "orf_validation": ORFValidationSchema,
            "off_target_hits": OffTargetHitsSchema,
        }
        schema = schemas.get(schema_type)
        if schema is None:
            result.add_error(f"Unknown schema type: {schema_type}")
            return result

        try:
            validated = schema.validate(df)
        except Exception as e:
            # Deliberate boundary: any schema failure is reported through the
            # result (and logged) instead of propagating to the caller.
            result.add_error(f"Schema validation failed: {str(e)}")
            logger.error(f"DataFrame schema validation error: {e}")
            return result

        result.add_metadata("validated_rows", len(validated))
        return result

    @staticmethod
    def validate_transcript_ids_consistency(
        candidate_ids: set[str], orf_ids: set[str], transcript_ids: set[str]
    ) -> ValidationResult:
        """Validate consistency of transcript IDs across datasets.

        Candidate or ORF IDs absent from ``transcript_ids`` are errors;
        transcripts referenced by neither dataset produce a warning.
        The size of each input set is recorded as metadata.
        """
        result = ValidationResult()

        candidates_missing = candidate_ids - transcript_ids
        if candidates_missing:
            result.add_error(f"Candidates reference unknown transcripts: {candidates_missing}")

        orf_missing = orf_ids - transcript_ids
        if orf_missing:
            result.add_error(f"ORF analysis references unknown transcripts: {orf_missing}")

        unused_transcripts = transcript_ids - candidate_ids - orf_ids
        if unused_transcripts:
            result.add_warning(f"Transcripts not used in analysis: {unused_transcripts}")

        result.add_metadata("candidate_count", len(candidate_ids))
        result.add_metadata("orf_count", len(orf_ids))
        result.add_metadata("transcript_count", len(transcript_ids))
        return result

    @staticmethod
    def validate_biological_constraints(candidate: SiRNACandidate) -> ValidationResult:
        """Validate bioinformatics-specific constraints.

        Produces warnings only (the result always stays valid) for:
            * homopolymer motifs (A/T/C/G and U runs of 4) in the guide,
              matched case-insensitively;
            * ``asymmetry_score`` below 0.2;
            * ``paired_fraction`` above 0.6;
            * ``gc_content`` outside 30-52%.
        """
        result = ValidationResult()

        # Upper-case like the other sequence checks in this class, and include
        # the U run so RNA-alphabet guides are screened the same as DNA ones.
        # "UUUU" is last so DNA inputs keep the original warning order.
        guide = candidate.guide_sequence.upper()
        forbidden_motifs = ("AAAA", "TTTT", "CCCC", "GGGG", "UUUU")
        for motif in forbidden_motifs:
            if motif in guide:
                result.add_warning(f"Forbidden motif {motif} found in guide sequence")

        if candidate.asymmetry_score < 0.2:
            result.add_warning("Low asymmetry score may reduce siRNA efficacy")
        if candidate.paired_fraction > 0.6:
            result.add_warning("High secondary structure may reduce accessibility")

        if candidate.gc_content < 30 or candidate.gc_content > 52:
            result.add_warning(f"GC content {candidate.gc_content:.1f}% outside optimal range (30-52%)")

        return result

    @staticmethod
    def _calculate_gc_content(sequence: str) -> float:
        """Return the G+C percentage of ``sequence`` (case-insensitive).

        The denominator counts only A/T/C/G/U, so ambiguity codes and gaps are
        excluded. Returns 0.0 for an empty sequence or one with no A/T/C/G/U.
        """
        if not sequence:
            return 0.0
        upper_seq = sequence.upper()
        gc_count = upper_seq.count("G") + upper_seq.count("C")
        total_count = sum(1 for c in upper_seq if c in "ATCGU")
        if total_count == 0:
            return 0.0
        return (gc_count / total_count) * 100.0

    @staticmethod
    def cross_validate_pydantic_pandera() -> ValidationResult:
        """Cross-validate Pydantic model constraints with Pandera schema constraints.

        Compares sequence-length bounds and GC-content ranges and records the
        checked constraint names as ``constraints_checked`` metadata.

        NOTE(review): both sides of each comparison are currently the same
        hard-coded values (module constants / literals), so this can never
        report an error. Reading the limits from the actual model and schema
        definitions would make the check meaningful.
        """
        result = ValidationResult()

        # siRNA length constraints, both taken from the module constants.
        pydantic_min_length = SIRNA_MIN_LENGTH
        pydantic_max_length = SIRNA_MAX_LENGTH
        pandera_min_length = SIRNA_MIN_LENGTH  # From SiRNACandidateSchema.check_sequence_lengths
        pandera_max_length = SIRNA_MAX_LENGTH
        if pydantic_min_length != pandera_min_length:
            result.add_error("Pydantic and Pandera minimum length constraints don't match")
        if pydantic_max_length != pandera_max_length:
            result.add_error("Pydantic and Pandera maximum length constraints don't match")

        # GC content ranges.
        pydantic_gc_min = 0.0  # From SiRNACandidate model
        pydantic_gc_max = 100.0
        pandera_gc_min = 0.0  # From schema
        pandera_gc_max = 100.0
        if pydantic_gc_min != pandera_gc_min or pydantic_gc_max != pandera_gc_max:
            result.add_error("Pydantic and Pandera GC content ranges don't match")

        result.add_metadata("constraints_checked", ["sequence_length", "gc_content"])
        return result