"""CRUD Blueprint for automatic RESTful API generation with Flask-Smorest.
This module provides a Blueprint subclass that automatically generates
RESTful CRUD (Create, Read, Update, Delete) endpoints for SQLAlchemy models
with Marshmallow schemas.
"""
import enum
import uuid
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from http import HTTPStatus
from importlib import import_module
from typing import TYPE_CHECKING, Any, TypedDict
import sqlalchemy as sa
from flask.views import MethodView
from flask_smorest import Blueprint
from flask_sqlalchemy.session import Session
from marshmallow import RAISE, Schema
from marshmallow_sqlalchemy import SQLAlchemySchema
from sqlalchemy.orm import scoped_session
from flask_more_smorest.sqla.base_model import BaseModel
from ..utils import convert_snake_to_camel
from .blueprint_operationid import BlueprintOperationIdMixin
from .pagination import CRUDPaginationMixin
from .query_filtering import generate_filter_schema, get_statements_from_filters
if TYPE_CHECKING:
from flask_smorest.pagination import PaginationParameters
[docs]
class CRUDMethod(enum.StrEnum):
"""Standard CRUD operations supported by CRUDBlueprint."""
INDEX = "INDEX"
GET = "GET"
POST = "POST"
PATCH = "PATCH"
DELETE = "DELETE"
[docs]
class MethodConfig(TypedDict, total=False):
    """Optional per-method settings for one CRUD method of a CRUDBlueprint.

    All keys are optional (``total=False``); an empty dict means "use defaults".
    """

    # Schema for the method's payload/response: a Schema class or its name.
    schema: type[Schema] | str
    # Schema for the request arguments; a Schema class or its name.
    arg_schema: type[Schema] | str
    # When true, the endpoint is passed to the blueprint's ``admin_endpoint``
    # (only possible on blueprints that include PermsBlueprintMixin).
    admin_only: bool
    # NOTE(review): not read by the CRUD blueprint itself; presumably consumed
    # by a permissions mixin - confirm.
    public: bool


# Dict form of the ``methods`` argument: each method maps to True (enable with
# defaults), False (disable), or a MethodConfig with custom settings.
MethodConfigMapping = Mapping[CRUDMethod, MethodConfig | bool]
[docs]
def resolve_schema(
    schema_candidate: type[Schema] | Schema | str | None,
    schema_import_path: str,
    default_schema: type[Schema] | None = None,
    *,
    context: str = "",
) -> type[Schema]:
    """Unified schema resolution for all contexts.

    Resolves a schema reference (class, instance, or string name) to a Schema
    class. This is the single source of truth for schema resolution throughout
    the blueprint.

    Args:
        schema_candidate: Schema to resolve - can be:
            - A Schema class: returned directly
            - A Schema instance: its class is returned (instance options such
              as ``partial`` or ``only`` are not carried over)
            - A string: looked up as an attribute of ``schema_import_path``
            - None: ``default_schema`` is returned if provided
        schema_import_path: Module path to import string schemas from
        default_schema: Fallback schema if candidate is None
        context: Description for error messages (e.g., "PATCH arg_schema")

    Returns:
        Resolved Schema class

    Raises:
        ValueError: If no schema/default is available, or the module or
            attribute named by a string candidate cannot be found
        TypeError: If the resolved value is not a Schema subclass

    Example:
        >>> schema_cls = resolve_schema("UserSchema", "myapp.schemas")
        >>> schema_cls = resolve_schema(UserSchema, "", default_schema=BaseSchema)
    """
    context_msg = f" for {context}" if context else ""

    def _describe(value: object) -> str:
        # For a class, name the class itself: type(cls).__name__ would only say "type".
        if isinstance(value, type):
            return f"class {value.__name__}"
        return type(value).__name__

    if schema_candidate is None:
        if default_schema is None:
            raise ValueError(f"No schema provided{context_msg} and no default available")
        return default_schema

    if isinstance(schema_candidate, str):
        try:
            schema_module = import_module(schema_import_path)
            resolved = getattr(schema_module, schema_candidate)
        except ImportError as e:
            raise ValueError(f"Could not import module '{schema_import_path}' for schema '{schema_candidate}'.") from e
        except AttributeError as e:
            raise ValueError(f"Could not find schema '{schema_candidate}' in '{schema_import_path}'.") from e
        # The imported attribute could be anything (e.g. a model class); reject non-schemas.
        if not (isinstance(resolved, type) and issubclass(resolved, Schema)):
            raise TypeError(f"Resolved schema{context_msg} must be a Schema subclass, got {_describe(resolved)}.")
        return resolved

    if isinstance(schema_candidate, type) and issubclass(schema_candidate, Schema):
        return schema_candidate

    if isinstance(schema_candidate, Schema):
        return schema_candidate.__class__

    # Defensive guard for unexpected types at runtime. The type checker marks
    # this as unreachable given the hints, but callers may not honour them.
    raise TypeError(  # type: ignore[unreachable]
        f"Schema{context_msg} must be a string, Schema subclass, or Schema instance; "
        f"got {_describe(schema_candidate)}."
    )
[docs]
@dataclass
class CRUDConfig:
    """Resolved settings a CRUDBlueprint uses when registering its routes."""

    # Blueprint name.
    name: str
    # URL prefix for the blueprint (defaults to "/<name>/" when not given).
    url_prefix: str
    # Import name the blueprint was created with.
    import_name: str
    # Model class served by the CRUD endpoints, and its class name
    # (the name is used to build operationIds such as "list<Model>").
    model_cls: type[BaseModel]
    model_name: str
    # Default schema class for the endpoints, and its class name.
    schema_cls: type[Schema]
    schema_name: str
    # Modules that string schema / model names are imported from.
    schema_import_path: str
    model_import_path: str
    # Model attribute used to look a resource up by id.
    res_id_name: str
    # Name of the id variable in the item URL route.
    res_id_param_name: str
    # Enabled CRUD methods mapped to their per-method configuration.
    methods: dict[CRUDMethod, MethodConfig]
[docs]
class CRUDBlueprint(  # pyright: ignore[reportIncompatibleMethodOverride]
    CRUDPaginationMixin,
    BlueprintOperationIdMixin,
    Blueprint,
):
    """Blueprint subclass that automatically registers CRUD routes.

    This class extends Flask-Smorest Blueprint to provide automatic CRUD
    (Create, Read, Update, Delete) operations for SQLAlchemy models.
    It automatically generates RESTful endpoints based on the provided
    model and schema configuration.

    Args:
        name: Blueprint name (first positional arg)
        import_name: Import name (second positional arg)
        model: Model class or string name to use. Defaults to the CamelCase
            form of ``name``.
        schema: Schema class or string name to use. Defaults to the model's
            ``Schema`` attribute.
        model_import_name: Module to import string model names from
            (defaults to the ``models`` module next to ``import_name``).
        schema_import_name: Module to import string schema names from
            (defaults to the ``schemas`` module next to ``import_name``).
        res_id: Model attribute used to look resources up by id.
        res_id_param: Name of the id variable in item URLs
            (defaults to ``"<name>_id"``).
        methods: Controls which CRUD methods to enable. Can be:

            - None (default): all CRUDMethod members are enabled
            - List of CRUDMethod: Only these methods are enabled
            - Dict mapping CRUDMethod to config: All methods enabled by default,
              unless explicitly set to False or present in skip_methods.
              Dict values can be:

              - True: Enable with defaults
              - False: Disable this method
              - MethodConfig dict: Enable with custom configuration

        skip_methods: List of CRUDMethod to explicitly disable.
            Applied after methods resolution. Useful when using dict-style
            methods to disable specific defaults.
        default_page_size: Page size for the paginated INDEX endpoint; None
            registers INDEX without pagination.
        db_session: Session used by the INDEX queries (defaults to
            ``flask_more_smorest.sqla.db.session``).
        static_folder, static_url_path, template_folder, url_prefix,
        subdomain, url_defaults, root_path, cli_group: Forwarded to the
            Flask-Smorest Blueprint constructor.

    Examples:
        Basic usage (all methods enabled):

        .. code-block:: python

            from myapp.models import Product

            blueprint = CRUDBlueprint(
                "products",
                __name__,
                model=Product,  # Use class (preferred)
                schema=Product.Schema,  # Auto-generated schema
            )

        Enable only specific methods:

        .. code-block:: python

            blueprint = CRUDBlueprint(
                "products",
                __name__,
                model=Product,
                schema=Product.Schema,
                methods=[CRUDMethod.INDEX, CRUDMethod.GET],  # Read-only
            )

        Disable specific methods:

        .. code-block:: python

            blueprint = CRUDBlueprint(
                "products",
                __name__,
                model=Product,
                schema=Product.Schema,
                skip_methods=[CRUDMethod.DELETE],  # All except delete
            )

        Advanced configuration (custom schemas, admin-only):

        .. code-block:: python

            blueprint = CRUDBlueprint(
                "products",
                __name__,
                model=Product,
                schema=Product.Schema,
                methods={
                    CRUDMethod.POST: {"schema": ProductCreateSchema},
                    CRUDMethod.DELETE: {"admin_only": True},
                    CRUDMethod.PATCH: False,  # Explicitly disable
                },
            )
    """

    _db_session: Session | scoped_session[Session]
    _config: CRUDConfig

    def __init__(
        self,
        name: str,
        import_name: str,
        model: type[BaseModel] | str | None = None,
        schema: type[Schema] | str | None = None,
        model_import_name: str | None = None,
        schema_import_name: str | None = None,
        res_id: str = "id",
        res_id_param: str | None = None,
        methods: list[CRUDMethod] | MethodConfigMapping | None = None,
        skip_methods: list[CRUDMethod] | None = None,
        default_page_size: int | None = 20,
        db_session: Session | scoped_session[Session] | None = None,
        static_folder: str | None = None,
        static_url_path: str | None = None,
        template_folder: str | None = None,
        url_prefix: str | None = None,
        subdomain: str | None = None,
        url_defaults: dict[str, Any] | None = None,
        root_path: str | None = None,
        cli_group: str | None = None,
    ) -> None:
        """Initialize CRUD blueprint with model and schema configuration.

        ``methods=None`` (the default) enables every CRUDMethod; a fresh
        config is built per instance instead of sharing a mutable default.
        """
        if db_session is None:
            from flask_more_smorest.sqla import db

            self._db_session = db.session
        else:
            self._db_session = db_session
        self._default_page_size = default_page_size

        self._config = self._build_config(
            name=name,
            import_name=import_name,
            model=model,
            schema=schema,
            model_import_name=model_import_name,
            schema_import_name=schema_import_name,
            res_id=res_id,
            res_id_param=res_id_param,
            methods=methods,
            skip_methods=skip_methods,
            url_prefix=url_prefix,
        )

        super().__init__(
            name,
            import_name,
            static_folder=static_folder,
            static_url_path=static_url_path,
            template_folder=template_folder,
            url_prefix=url_prefix or self._config.url_prefix,
            subdomain=subdomain,
            url_defaults=url_defaults,
            root_path=root_path,
            cli_group=cli_group,
        )

        update_schema = self._prepare_update_schema(self._config)
        self._register_crud_routes(self._config, update_schema)

    def _build_config(
        self,
        name: str,
        import_name: str,
        model: type[BaseModel] | str | None,
        schema: type[Schema] | str | None,
        model_import_name: str | None,
        schema_import_name: str | None,
        res_id: str,
        res_id_param: str | None,
        methods: list[CRUDMethod] | MethodConfigMapping | None,
        skip_methods: list[CRUDMethod] | None,
        url_prefix: str | None,
    ) -> CRUDConfig:
        """Build and validate configuration.

        Method resolution order:
            1. Normalize methods parameter (None → all, list → dict, or process
               dict with defaults)
            2. Apply skip_methods to remove explicitly disabled methods
            3. Return final enabled methods configuration

        Args:
            name: Blueprint name
            import_name: Import name for the blueprint
            model: Model class or string name
            schema: Schema class or string name
            model_import_name: Module path to import model from
            schema_import_name: Module path to import schema from
            res_id: Name of the resource ID field on the model
            res_id_param: Name for the ID parameter in URL routes
            methods: Methods configuration (None, list or mapping)
            skip_methods: Methods to explicitly disable after normalization
            url_prefix: URL prefix for the blueprint

        Returns:
            Validated CRUDConfig object
        """
        resolved_url_prefix: str = url_prefix or f"/{name}/"
        # Default import paths are siblings of import_name: "<parent>.models" / "<parent>.schemas".
        parent_parts = import_name.split(".")[:-1]
        resolved_model_import_path: str = model_import_name or ".".join([*parent_parts, "models"])
        resolved_schema_import_path: str = schema_import_name or ".".join([*parent_parts, "schemas"])

        # Resolve model class
        model_cls = self._resolve_model_class(
            model or convert_snake_to_camel(name.capitalize()),
            resolved_model_import_path,
        )

        # Resolve schema class using unified resolver
        schema_cls = resolve_schema(
            schema,
            resolved_schema_import_path,
            default_schema=model_cls.Schema,
            context="blueprint schema",
        )

        res_id_param_name: str = res_id_param or f"{name.lower()}_id"

        # Step 1: Normalize methods into a dict with configs
        normalized_methods = self._normalize_methods(methods)

        # Step 2: Apply skip_methods after normalization so it behaves the same
        # whether methods was given as a list or as a mapping.
        if skip_methods:
            import warnings

            for method_to_skip in skip_methods:
                method_enum = CRUDMethod(method_to_skip)
                normalized_methods.pop(method_enum, None)
                # Warn if the mapping already disabled this method (redundant).
                # Mapping (not dict) matches what _normalize_methods accepts.
                if isinstance(methods, Mapping) and methods.get(method_enum) is False:
                    warnings.warn(
                        f"Method {method_enum.value} is set to False in 'methods' dict "
                        f"and also appears in 'skip_methods'. The skip_methods entry is redundant.",
                        UserWarning,
                        stacklevel=3,
                    )

        return CRUDConfig(
            name=name,
            url_prefix=resolved_url_prefix,
            import_name=import_name,
            model_cls=model_cls,
            model_name=model_cls.__name__,
            schema_cls=schema_cls,
            schema_name=schema_cls.__name__,
            schema_import_path=resolved_schema_import_path,
            model_import_path=resolved_model_import_path,
            res_id_name=res_id,
            res_id_param_name=res_id_param_name,
            methods=normalized_methods,
        )

    def _normalize_methods(
        self,
        methods_raw: list[CRUDMethod] | MethodConfigMapping | None,
    ) -> dict[CRUDMethod, MethodConfig]:
        """Normalize different method inputs into a standard dict.

        Behavior:
            - If methods_raw is None: every CRUDMethod is enabled with defaults
            - If methods_raw is a list: Only those methods are enabled (explicit whitelist)
            - If methods_raw is a mapping: All methods are enabled by default, unless:
                - Explicitly set to False in the mapping
                - Removed later by skip_methods

        Args:
            methods_raw: None, a list of methods to enable, or a mapping of
                methods to their configuration

        Returns:
            Normalized dict mapping enabled methods to their configuration

        Raises:
            TypeError: If methods_raw is not None, a list or a mapping, or if
                mapping values are invalid
        """
        if methods_raw is None:
            return {method: {} for method in CRUDMethod}

        if isinstance(methods_raw, list):
            # List mode: explicit whitelist - only these methods are enabled
            return {CRUDMethod(item): {} for item in methods_raw}

        if not isinstance(methods_raw, Mapping):
            raise TypeError(
                f"CRUDBlueprint 'methods' argument must be a list or a dict, got {type(methods_raw).__name__}"
            )

        # Mapping mode: start with every method enabled, then apply overrides.
        normalized: dict[CRUDMethod, MethodConfig] = {method: {} for method in CRUDMethod}
        for method, config in methods_raw.items():
            key = CRUDMethod(method)
            if config is True:
                # Explicitly enabled with defaults (redundant but allowed)
                normalized[key] = {}
            elif config is False:
                # Explicitly disabled
                normalized.pop(key, None)
            elif isinstance(config, dict):
                # Custom configuration provided
                normalized[key] = config
            else:
                raise TypeError(
                    f"CRUDBlueprint method config for {method} must be a dict, True, or False; "
                    f"got {type(config).__name__}"
                )
        return normalized

    @staticmethod
    def _resolve_model_class(
        model_candidate: type[BaseModel] | str,
        model_import_path: str,
    ) -> type[BaseModel]:
        """Resolve a model reference (class or string name) to a BaseModel subclass.

        Args:
            model_candidate: Model class or string name to resolve
            model_import_path: Module path to import string models from

        Returns:
            Resolved BaseModel subclass

        Raises:
            ValueError: If model cannot be imported or resolved
        """
        if isinstance(model_candidate, str):
            try:
                model_module = import_module(model_import_path)
                model_cls = getattr(model_module, model_candidate)
            except ImportError as e:
                raise ValueError(f"Could not import module '{model_import_path}' for model '{model_candidate}'.") from e
            except AttributeError as e:
                raise ValueError(f"Could not find model '{model_candidate}' in '{model_import_path}'.") from e
            # Validate imported model is actually a BaseModel subclass
            if not isinstance(model_cls, type) or not issubclass(model_cls, BaseModel):
                raise ValueError(
                    f"Imported '{model_candidate}' from '{model_import_path}' "
                    f"is not a BaseModel subclass, got {type(model_cls).__name__}."
                )
            return model_cls

        if isinstance(model_candidate, type) and issubclass(model_candidate, BaseModel):
            return model_candidate

        raise ValueError(
            f"CRUDBlueprint 'model' argument must be a string or a BaseModel subclass, "
            f"got {type(model_candidate).__name__}."
        )

    def _resolve_schema_class(
        self,
        schema_candidate: type[Schema] | Schema | str,
        *,
        config: CRUDConfig,
        method: CRUDMethod,
    ) -> type[Schema]:
        """Resolve a schema reference for a specific CRUD method.

        This is a convenience wrapper around the module-level resolve_schema
        function, providing method-specific context for error messages.

        Args:
            schema_candidate: Schema to resolve
            config: CRUD configuration with import paths
            method: The CRUD method this schema is for

        Returns:
            Resolved Schema class
        """
        return resolve_schema(
            schema_candidate,
            config.schema_import_path,
            context=f"{method.value} method",
        )

    def _method_schema(
        self,
        config: CRUDConfig,
        method: CRUDMethod,
    ) -> type[Schema] | Schema:
        """Return the ``schema`` configured for ``method``, or the blueprint schema.

        MethodConfig allows schema names as strings; those are resolved against
        the schema import path. Classes and instances are returned unchanged so
        instance options (e.g. ``only``/``exclude``) are preserved.

        Args:
            config: CRUD configuration
            method: CRUD method whose schema is requested

        Returns:
            Schema class or instance to pass to ``arguments``/``response``
        """
        candidate = config.methods.get(method, {}).get("schema", config.schema_cls)
        if isinstance(candidate, str):
            return self._resolve_schema_class(candidate, config=config, method=method)
        return candidate

    @staticmethod
    def _url_converter_for(column_type: Any) -> str:
        """Choose the URL converter used for the resource-id route segment.

        CHAR types map to ``uuid`` (presumably UUIDs rendered as CHAR(32)).
        Type names that already are built-in werkzeug converters (``uuid``,
        ``float``, ...) are used as-is. Other SQL names such as ``integer`` or
        ``varchar`` are not valid converters and would fail when the route is
        registered, so they are mapped from the column's Python type instead.

        Args:
            column_type: SQLAlchemy type of the id column

        Returns:
            Converter name for the ``<converter:param>`` route segment
        """
        type_name = str(column_type).lower()
        if type_name.startswith("char"):
            return "uuid"
        if type_name in {"default", "string", "any", "path", "int", "float", "uuid"}:
            return type_name
        try:
            python_type = column_type.python_type
        except (NotImplementedError, AttributeError):
            return type_name
        converters: dict[type, str] = {int: "int", float: "float", uuid.UUID: "uuid", str: "string"}
        return converters.get(python_type, type_name)

    def _prepare_update_schema(
        self, config: CRUDConfig
    ) -> Schema | type[Schema] | SQLAlchemySchema | type[SQLAlchemySchema]:
        """Create update schema for PATCH operations.

        If an explicit arg_schema is provided in PATCH method config, it's used.
        Otherwise, creates a partial version of the default schema.

        Args:
            config: Configuration object

        Returns:
            Update schema instance or class
        """
        patch_config = config.methods.get(CRUDMethod.PATCH, {})
        update_schema_arg = patch_config.get("arg_schema")
        if update_schema_arg is not None:
            # Explicit patch schema provided - use unified resolver
            return resolve_schema(
                update_schema_arg,
                config.schema_import_path,
                context="PATCH arg_schema",
            )
        # Create partial schema from default
        # NOTE: the following will trigger a warning in apispec if no custom resolver is set
        update_schema = config.schema_cls(partial=True)
        if isinstance(update_schema, SQLAlchemySchema):
            # Load a plain dict (applied via res.update) rather than a model instance.
            update_schema._load_instance = False
        return update_schema

    def _register_crud_routes(
        self,
        config: CRUDConfig,
        update_schema: Schema | type[Schema],
    ) -> None:
        """Register all enabled CRUD routes for the blueprint.

        INDEX/POST are served by a collection view at ``""``; GET/PATCH/DELETE
        by an item view at ``"<converter:res_id_param_name>"``. Each view is
        only registered when at least one of its methods is enabled.

        Args:
            config: Configuration object
            update_schema: Argument schema for PATCH operations
        """
        id_type = self._url_converter_for(getattr(config.model_cls, config.res_id_name).type)
        model_cls: type[BaseModel] = config.model_cls
        schema_cls: type[Schema] = config.schema_cls
        # Per-method payload/response schemas (string names resolved consistently).
        post_schema = self._method_schema(config, CRUDMethod.POST)
        get_schema = self._method_schema(config, CRUDMethod.GET)
        patch_schema = self._method_schema(config, CRUDMethod.PATCH)

        if CRUDMethod.INDEX in config.methods or CRUDMethod.POST in config.methods:
            # Initialize variables to avoid "possibly unbound" errors
            index_schema_class: type[Schema] | None = None
            query_filter_schema: type[Schema] | None = None
            if CRUDMethod.INDEX in config.methods:
                index_schema_candidate = config.methods[CRUDMethod.INDEX].get("schema", schema_cls)
                # INDEX needs a class: it is instantiated with many=True and used for filters.
                index_schema_class = self._resolve_schema_class(
                    index_schema_candidate, config=config, method=CRUDMethod.INDEX
                )
                query_filter_schema = generate_filter_schema(base_schema=index_schema_class)

            class GenericIndex(MethodView):
                """Index/Post endpoints."""

                if CRUDMethod.INDEX in config.methods:
                    if self._default_page_size is not None:

                        @self.arguments(query_filter_schema, location="query", unknown=RAISE)  # pyright: ignore[reportArgumentType]
                        @self.response(HTTPStatus.OK, index_schema_class(many=True))  # type: ignore[misc] # pyright: ignore[reportArgumentType, reportOptionalCall]
                        @self.paginate(page_size=self._default_page_size)
                        @self.doc(operationId=f"list{config.model_name}")
                        def get(
                            _self,  # NOTE: using _self to avoid collision with outer self
                            filters: dict,
                            pagination_parameters: "PaginationParameters",
                            **kwargs: Any,
                        ) -> Sequence[BaseModel]:
                            """Fetch all resources.

                            kwargs might contain path parameters to filter by (eg /user/<uuid:user_id>/roles/)
                            """
                            stmts = get_statements_from_filters(filters, model=model_cls)
                            base_query = sa.select(model_cls).filter_by(**kwargs).filter(*stmts)
                            count_query = sa.select(sa.func.count()).select_from(base_query.subquery())
                            total_items = self._db_session.scalar(count_query)
                            pagination_parameters.item_count = total_items  # pyright: ignore[reportAttributeAccessIssue]
                            if pagination_parameters.page_size > 0:
                                paginated_query = base_query.limit(pagination_parameters.page_size).offset(
                                    pagination_parameters.page_size * (pagination_parameters.page - 1)
                                )
                            else:
                                paginated_query = base_query
                            return self._db_session.execute(paginated_query).scalars().all()

                    else:

                        @self.arguments(query_filter_schema, location="query", unknown=RAISE)  # pyright: ignore[reportArgumentType]
                        @self.response(HTTPStatus.OK, index_schema_class(many=True))  # type: ignore[misc] # pyright: ignore[reportArgumentType, reportOptionalCall]
                        @self.doc(operationId=f"list{config.model_name}")
                        def get(
                            _self,
                            filters: dict,
                            **kwargs: Any,
                        ) -> Sequence[BaseModel]:
                            """Fetch all resources (no pagination)."""
                            stmts = get_statements_from_filters(filters, model=model_cls)
                            base_query = sa.select(model_cls).filter_by(**kwargs).filter(*stmts)
                            return self._db_session.execute(base_query).scalars().all()

                if CRUDMethod.POST in config.methods:

                    @self.arguments(post_schema)
                    @self.response(HTTPStatus.OK, post_schema)
                    @self.doc(
                        responses={
                            HTTPStatus.NOT_FOUND: {"description": f"{config.name} resource not found"},
                            HTTPStatus.CONFLICT: {"description": "DB error."},
                        },
                        operationId=f"create{config.model_name}",
                    )
                    def post(
                        _self,
                        new_object: BaseModel,
                        **kwargs: str | int | float | bool | bytes | None,
                    ) -> BaseModel:
                        """Create and return new resource."""
                        new_object.update(commit=True, **kwargs)
                        return new_object

            self._configure_endpoint(
                GenericIndex,
                "get",
                f"Fetch all {config.name} resources.",
                config.methods.get(CRUDMethod.INDEX, {}),
            )
            self._configure_endpoint(
                GenericIndex,
                "post",
                f"Create and return new {config.name}.",
                config.methods.get(CRUDMethod.POST, {}),
            )
            self.route("")(GenericIndex)

        class GenericCRUD(MethodView):
            """Resource-specific endpoints."""

            if CRUDMethod.GET in config.methods:

                @self.doc(
                    responses={HTTPStatus.NOT_FOUND: {"description": f"{config.name} not found"}},
                    operationId=f"get{config.model_name}",
                )
                @self.response(HTTPStatus.OK, get_schema)
                def get(_self, **kwargs: Any) -> BaseModel:
                    """Fetch resource by ID."""
                    # Rename the URL variable to the model's id attribute.
                    kwargs[config.res_id_name] = kwargs.pop(config.res_id_param_name)
                    return model_cls.get_by_or_404(**kwargs)

            if CRUDMethod.PATCH in config.methods:

                @self.arguments(update_schema)
                @self.doc(
                    responses={
                        HTTPStatus.NOT_FOUND: {"description": f"{config.name} not found"},
                        HTTPStatus.CONFLICT: {"description": "DB error."},
                    },
                    operationId=f"update{config.model_name}",
                )
                @self.response(HTTPStatus.OK, patch_schema)
                def patch(_self, payload: dict, **kwargs: str | int | uuid.UUID | bool | None) -> BaseModel:
                    """Update resource."""
                    kwargs[config.res_id_name] = kwargs.pop(config.res_id_param_name)
                    res = model_cls.get_by_or_404(**kwargs)
                    res.update(**payload)
                    return res

            if CRUDMethod.DELETE in config.methods:

                @self.response(HTTPStatus.NO_CONTENT, description=f"{config.name} deleted")
                @self.doc(operationId=f"delete{config.model_name}")
                def delete(_self, **kwargs: str | int | uuid.UUID | bool | None) -> tuple[str, int]:
                    """Delete resource."""
                    kwargs[config.res_id_name] = kwargs.pop(config.res_id_param_name)
                    res = model_cls.get_by_or_404(**kwargs)
                    res.delete()
                    return "", HTTPStatus.NO_CONTENT

            if "PUT" in config.methods:
                raise NotImplementedError("PUT method is not implemented. Use PATCH instead.")

        self._configure_endpoint(
            GenericCRUD,
            "get",
            f"Fetch {config.name} by ID.",
            config.methods.get(CRUDMethod.GET, {}),
        )
        self._configure_endpoint(
            GenericCRUD,
            "patch",
            f"Update {config.name} by ID.",
            config.methods.get(CRUDMethod.PATCH, {}),
        )
        self._configure_endpoint(
            GenericCRUD,
            "delete",
            f"Delete {config.name} by ID.",
            config.methods.get(CRUDMethod.DELETE, {}),
        )

        # Only register GenericCRUD if it has at least one method
        if any(method in config.methods for method in (CRUDMethod.GET, CRUDMethod.PATCH, CRUDMethod.DELETE)):
            self.route(f"<{id_type}:{config.res_id_param_name}>")(GenericCRUD)

    def _configure_endpoint(
        self,
        view_cls: type[MethodView],
        method_name: str,
        docstring: str,
        method_config: MethodConfig,
    ) -> None:
        """Configure endpoint with docstring and admin decorator if needed.

        Does nothing when ``view_cls`` has no attribute named ``method_name``
        (i.e. the method is disabled).

        Args:
            view_cls: MethodView class containing the endpoint
            method_name: Name of the method to configure
            docstring: Docstring to set on the method
            method_config: Configuration dict for the method

        Raises:
            TypeError: If ``admin_only`` is set but the blueprint does not
                inherit from PermsBlueprintMixin
        """
        if hasattr(view_cls, method_name):
            method = getattr(view_cls, method_name)
            method.__doc__ = docstring
            if method_config.get("admin_only", False):
                from ..perms import PermsBlueprintMixin

                if isinstance(self, PermsBlueprintMixin):
                    self.admin_endpoint(method)
                else:
                    raise TypeError("Blueprint must inherit from PermsBlueprintMixin to set admin_only endpoint.")