generated from erangel1/generic-template
145 lines
4.4 KiB
Python
145 lines
4.4 KiB
Python
"""
Pydantic v2 input validation schemas for the LabGraph API.

These schemas validate incoming request data at the API boundary BEFORE
it reaches the Django ORM. A ValidationError here is converted to an
HTTP 400 in the view layer (Phase 2).

Keeping validation here (not in DRF serializers) means the same schemas
can be reused by Celery discovery tasks when upserting auto-discovered nodes.
"""
|
|
|
|
import ipaddress
|
|
import re
|
|
import uuid
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel, Field, field_validator
|
|
|
|
# Six colon-separated pairs of hex digits, upper or lower case,
# e.g. "aa:bb:cc:dd:ee:ff".
_MAC_PATTERN = re.compile(r"^([0-9a-fA-F]{2}:){5}[0-9a-fA-F]{2}$")

# Accepted values for NodeCreateSchema.node_type.
VALID_NODE_TYPES = {
    "location",
    "hardware",
    "hypervisor",
    "vm",
    "container",
    "application",
    "network_device",
}

# Accepted values for EdgeCreateSchema.edge_type.
VALID_EDGE_TYPES = {
    "parent_child",
    "network",
    "dependency",
    "physical",
}

# Accepted values for NodeCreateSchema.discovery_source ("manual" is its default).
VALID_DISCOVERY_SOURCES = {
    "manual",
    "proxmox",
    "nmap",
    "snmp",
}
|
|
|
|
|
|
class NodeCreateSchema(BaseModel):
    """Validates POST /api/v1/nodes/ request body.

    Field-level checks:
      * ``node_type`` must be in ``VALID_NODE_TYPES``.
      * ``ip_address``, when given, must parse via :func:`ipaddress.ip_address`
        (IPv4 or IPv6); the original string is returned unchanged.
      * ``mac_address``, when given, must be six colon-separated hex pairs
        and is returned lower-cased.
      * ``discovery_source`` must be in ``VALID_DISCOVERY_SOURCES``.

    Each validator raises ``ValueError``, which Pydantic wraps in a
    ``ValidationError``.
    """

    label: str = Field(min_length=1, max_length=255)
    node_type: str
    ip_address: str | None = None
    mac_address: str | None = None
    # Optional; must be strictly greater than 0 when supplied.
    wattage: float | None = Field(default=None, gt=0)
    metadata: dict[str, Any] = Field(default_factory=dict)
    discovery_source: str = "manual"
    # NOTE(review): presumably an identifier from the discovery source
    # (used when upserting discovered nodes) — confirm.
    external_id: str | None = Field(default=None, max_length=255)

    @field_validator("node_type")
    @classmethod
    def validate_node_type(cls, v: str) -> str:
        """Reject any node_type not listed in VALID_NODE_TYPES."""
        if v not in VALID_NODE_TYPES:
            raise ValueError(f"node_type must be one of: {sorted(VALID_NODE_TYPES)}")
        return v

    @field_validator("ip_address")
    @classmethod
    def validate_ip_address(cls, v: str | None) -> str | None:
        """Accept None or any string ``ipaddress.ip_address`` can parse.

        Raises:
            ValueError: if ``v`` is not a valid IPv4/IPv6 address.
        """
        if v is None:
            return v
        try:
            ipaddress.ip_address(v)
        except ValueError:
            # The message already names the bad value; drop the inner traceback.
            raise ValueError(f"Invalid IP address: {v!r}") from None
        return v

    @field_validator("mac_address")
    @classmethod
    def validate_mac_address(cls, v: str | None) -> str | None:
        """Accept None or an ``aa:bb:cc:dd:ee:ff`` MAC, returned lower-cased.

        Raises:
            ValueError: if ``v`` does not match the expected format.
        """
        if v is None:
            return v
        # fullmatch, not match: the pattern ends in `$`, which also matches just
        # before a trailing "\n", so `match` would accept (and then store)
        # "aa:bb:cc:dd:ee:ff\n".
        if not _MAC_PATTERN.fullmatch(v):
            raise ValueError("MAC address must be in aa:bb:cc:dd:ee:ff format")
        return v.lower()

    @field_validator("discovery_source")
    @classmethod
    def validate_discovery_source(cls, v: str) -> str:
        """Reject any discovery_source not listed in VALID_DISCOVERY_SOURCES."""
        if v not in VALID_DISCOVERY_SOURCES:
            raise ValueError(f"discovery_source must be one of: {sorted(VALID_DISCOVERY_SOURCES)}")
        return v
|
|
|
|
|
|
class NodeUpdateSchema(NodeCreateSchema):
    """Validates PATCH /api/v1/nodes/<id>/ — all fields optional.

    Every field defaults to ``None``, so an omitted field never carries a
    concrete value into the update. Previously ``discovery_source`` defaulted
    to ``"manual"`` and ``metadata`` to ``{}``, so any PATCH applied with
    ``model_dump(exclude_none=True)`` (or a plain dump) silently reset them.
    Apply updates with ``model_dump(exclude_unset=True)`` or
    ``exclude_none=True``.

    All validators are inherited from NodeCreateSchema and run only on values
    the client actually sends. An explicit ``null`` is still rejected for:
      * ``node_type`` and ``discovery_source``, by the inherited validators;
      * ``label`` and ``metadata``, by ``reject_explicit_null`` below.
    """

    label: str | None = Field(default=None, min_length=1, max_length=255)  # type: ignore[assignment]
    node_type: str | None = None  # type: ignore[assignment]
    metadata: dict[str, Any] | None = None  # type: ignore[assignment]
    discovery_source: str | None = None  # type: ignore[assignment]

    @field_validator("label", "metadata")
    @classmethod
    def reject_explicit_null(cls, v: Any, info: Any) -> Any:
        """Reject a client-supplied ``null`` for fields that are not nullable.

        Defaults are not validated, so omitting the field is still allowed.

        Raises:
            ValueError: if the client explicitly sent ``null``.
        """
        if v is None:
            raise ValueError(f"{info.field_name} may not be null")
        return v
|
|
|
|
|
|
class EdgeCreateSchema(BaseModel):
    """Request-body schema for POST /api/v1/edges/.

    An edge links two different nodes, referenced by UUID, with a type drawn
    from ``VALID_EDGE_TYPES`` and a strictly positive weight (default 1.0).
    """

    source: uuid.UUID
    target: uuid.UUID
    edge_type: str
    weight: float = Field(default=1.0, gt=0)
    label: str = Field(default="", max_length=255)
    metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("edge_type")
    @classmethod
    def validate_edge_type(cls, v: str) -> str:
        """Pass through edge types from VALID_EDGE_TYPES; reject everything else."""
        if v in VALID_EDGE_TYPES:
            return v
        raise ValueError(f"edge_type must be one of: {sorted(VALID_EDGE_TYPES)}")

    @field_validator("target")
    @classmethod
    def source_and_target_differ(cls, v: uuid.UUID, info: Any) -> uuid.UUID:
        """Forbid self-loops: ``target`` must not equal ``source``.

        ``info.data`` holds only the previously validated fields that passed;
        if ``source`` failed validation it is absent and the check is skipped.
        """
        validated_source = info.data.get("source")
        if validated_source is not None and validated_source == v:
            raise ValueError("source and target must be different nodes")
        return v
|
|
|
|
|
|
class NetworkCreateSchema(BaseModel):
    """Validates POST /api/v1/networks/ request body.

    ``cidr`` is returned in canonical ``<network-address>/<prefix-length>``
    form, e.g. ``"192.168.1.5/24"`` becomes ``"192.168.1.0/24"``.
    ``gateway``, when given, must be a valid IPv4/IPv6 address and is
    returned as supplied.
    """

    name: str = Field(min_length=1, max_length=255)
    cidr: str
    # Optional; restricted to the range 1..4094 inclusive.
    vlan_id: int | None = Field(default=None, ge=1, le=4094)
    gateway: str | None = None
    description: str = ""

    @field_validator("cidr")
    @classmethod
    def validate_cidr(cls, v: str) -> str:
        """Parse ``v`` as a network and return its canonical CIDR string.

        Raises:
            ValueError: if ``v`` cannot be parsed as an IPv4/IPv6 network.
        """
        try:
            network = ipaddress.ip_network(v, strict=False)
        except ValueError:
            raise ValueError(f"Invalid CIDR notation: {v!r}") from None
        # strict=False tolerates host bits, and ip_network also accepts netmask
        # ("10.0.0.0/255.0.0.0"), hostmask and bare-address forms. Returning the
        # raw input would store those verbatim in a field meant to hold CIDR
        # notation, and the same network could be stored under several spellings.
        return str(network)

    @field_validator("gateway")
    @classmethod
    def validate_gateway(cls, v: str | None) -> str | None:
        """Accept None or any string ``ipaddress.ip_address`` can parse.

        Raises:
            ValueError: if ``v`` is not a valid IPv4/IPv6 address.
        """
        if v is None:
            return v
        try:
            ipaddress.ip_address(v)
        except ValueError:
            raise ValueError(f"Invalid gateway IP address: {v!r}") from None
        return v
|
|
|
|
|
|
class WikiPageUpdateSchema(BaseModel):
    """Request-body schema for PATCH /api/v1/wiki/<node_id>/.

    Holds a single free-text ``content`` field, which falls back to the
    empty string when the key is absent.
    """

    content: str = Field(default="")
|