from typing import Literal
from typing import Optional
from typing import Tuple
import gemmi
from esrf_pathlib import ESRFPath
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
class SpaceGroupModel(BaseModel):
    # Holds one space-group name. Validation replaces it with gemmi's short
    # name for the same group (see normalize_sg).
    sg_name: str

    @field_validator("sg_name")
    @classmethod
    def normalize_sg(cls, input_sg: str) -> str:
        """Normalise *input_sg* to gemmi's short space-group name.

        The group is looked up with ``gemmi.find_spacegroup_by_name`` and its
        ``short_name()`` is returned. If the caller wrote the lattice letter
        as ``R`` but the short name starts with ``H`` (or the reverse), the
        caller's letter is kept.

        Args:
            input_sg: Space-group name as supplied by the user.

        Returns:
            The normalised short name.

        Raises:
            ValueError: If gemmi finds no space group named ``input_sg``.
        """
        sg = gemmi.find_spacegroup_by_name(input_sg)
        if sg is None:
            raise ValueError(f"Unknown space group: {input_sg}")
        short_name = sg.short_name()

        # Respect the user's R/H lattice choice. Read the requested letter
        # after dropping leading blanks and upper-casing it, so inputs like
        # " R3" or "r3" (if the lookup accepts them) keep the R form instead
        # of silently coming back as H.
        requested = input_sg.lstrip()[:1].upper()
        if requested == "R" and short_name.startswith("H"):
            short_name = "R" + short_name[1:]
        elif requested == "H" and short_name.startswith("R"):
            short_name = "H" + short_name[1:]
        return short_name
class DataCollectionModel(BaseModel):
    # Acquisition metadata for one diffraction data collection: where it was
    # recorded, the detector and beam geometry, and how the rotation range
    # was split into wedges. Every field is required (default ``...``).

    # --- instrument ------------------------------------------------------
    beamline: str = Field(
        default=...,
        examples=["id30a1", "id23eh1"],
        description="Name of beamline where data were collected",
    )
    is_finesliced: bool = Field(
        default=...,
        examples=[True, False],
        description="Fine-sliced diffraction data",
    )
    detector_type: str = Field(
        default=...,
        examples=["eiger9m", "pilatus4_4m"],
        description="Type of the detector used",
    )

    # --- beam / detector geometry ----------------------------------------
    # NOTE(review): the description says millimetres but the example value
    # (0.242) reads like metres; confirm which unit callers actually send.
    detector_distance: float = Field(
        default=...,
        gt=0,
        examples=[0.242],
        description="Crystal to detector distance in mm",
    )
    wavelength: float = Field(
        default=...,
        gt=0,
        examples=[0.989],
        description="Wavelength in Angstroms",
    )
    beam_position_x: float = Field(
        default=...,
        examples=[1023.2],
        description="x coordinate of beam center in pixels",
    )
    beam_position_y: float = Field(
        default=...,
        examples=[1143.9],
        description="y coordinate of beam center in pixels",
    )

    # --- rotation / wedge layout -----------------------------------------
    num_wedges: int = Field(
        default=...,
        ge=1,
        examples=[4],
        description="Number of data collection wedges",
    )
    rotation_angle_between_wedges: float = Field(
        default=...,
        ge=0,
        examples=[90],
        description="Rotation angle between each data collection wedge",
    )
    # NOTE(review): ge=0 lets a zero oscillation width validate; confirm
    # that gt=0 was not intended.
    oscillation_width: float = Field(
        default=...,
        ge=0,
        examples=[1.0],
        description="Rotation angle for one image",
    )
    num_images_per_wedge: int = Field(
        default=...,
        ge=1,
        examples=[10],
        description="Number of images per data collection wedge",
    )
class ProcessingPlanModel(BaseModel):
    # User choices that steer processing: the strategy flavour, optional
    # forced symmetry (space group and/or unit cell) and optional targets
    # for resolution, multiplicity and completeness.

    strategy_type: Literal["full", "fast"] = Field(
        "full",
        description="Data collection strategy",
        examples=["full", "fast"],
    )
    # Normalised by validate_sg (defined after the fields).
    forced_space_group: Optional[str] = Field(
        None,
        description="Forced space group",
        examples=["P1", "P212121"],
    )
    forced_cell: Optional[Tuple[float, float, float, float, float, float]] = Field(
        None,
        description="Forced unit cell in format [a, b, c, alpha, beta, gamma]",
        examples=[[37, 54, 67, 90, 90, 90]],
    )
    aimed_resolution: Optional[float] = Field(
        None,
        gt=0,
        description="Aimed data collection resolution in Angstroms",
        examples=[2.0],
    )
    aimed_multiplicity: Optional[float] = Field(
        None, gt=0, description="Aimed data collection multiplicity", examples=[4.0]
    )
    aimed_completeness: Optional[float] = Field(
        None,
        ge=0,
        le=100,
        description="Aimed data collection completeness in %",
        examples=[90.0],
    )
    anomalous_data: bool = Field(
        False, description="Anomalous data collection strategy", examples=[True, False]
    )

    @field_validator("forced_space_group")
    @classmethod
    def validate_sg(cls, v: Optional[str]) -> Optional[str]:
        """Normalise a forced space group exactly as SpaceGroupModel does.

        Args:
            v: Space-group name, or ``None`` when no space group is forced.

        Returns:
            ``None`` unchanged, otherwise the name as normalised by
            ``SpaceGroupModel.normalize_sg``.

        Raises:
            ValueError: If the space group is unknown.
        """
        if v is None:
            return None
        # Call the shared validator directly rather than constructing a
        # SpaceGroupModel. A failed nested model raises a ValidationError,
        # and its multi-line report would be embedded in this field's error.
        # The plain ValueError keeps the message "Unknown space group: ...".
        return SpaceGroupModel.normalize_sg(v)
[docs]
class WorkflowParametersModel(BaseModel, arbitrary_types_allowed=True):
working_directory: ESRFPath = Field(None, description="Workflow working directory")