Source code for flask_more_smorest.crud.crud_blueprint

"""CRUD Blueprint for automatic RESTful API generation with Flask-Smorest.

This module provides a Blueprint subclass that automatically generates
RESTful CRUD (Create, Read, Update, Delete) endpoints for SQLAlchemy models
with Marshmallow schemas.
"""

import enum
import uuid
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from http import HTTPStatus
from importlib import import_module
from typing import TYPE_CHECKING, Any, TypedDict

import sqlalchemy as sa
from flask.views import MethodView
from flask_smorest import Blueprint
from flask_sqlalchemy.session import Session
from marshmallow import RAISE, Schema
from marshmallow_sqlalchemy import SQLAlchemySchema
from sqlalchemy.orm import scoped_session

from flask_more_smorest.sqla.base_model import BaseModel

from ..utils import convert_snake_to_camel
from .blueprint_operationid import BlueprintOperationIdMixin
from .pagination import CRUDPaginationMixin
from .query_filtering import generate_filter_schema, get_statements_from_filters

if TYPE_CHECKING:
    from flask_smorest.pagination import PaginationParameters


[docs] class CRUDMethod(enum.StrEnum): """Standard CRUD operations supported by CRUDBlueprint.""" INDEX = "INDEX" GET = "GET" POST = "POST" PATCH = "PATCH" DELETE = "DELETE"
class MethodConfig(TypedDict, total=False):
    """Optional per-method settings for a single CRUD operation.

    Every key may be omitted (``total=False``); an empty dict means
    "enabled with defaults".
    """

    # Schema class, or the name of a schema, to use for this method instead
    # of the blueprint-wide schema.
    schema: type[Schema] | str

    # Schema class, or the name of a schema, describing the request payload.
    # In this module it is only consulted for PATCH.
    arg_schema: type[Schema] | str

    # When true, the endpoint is registered through the blueprint's
    # ``admin_endpoint`` hook.
    admin_only: bool

    # NOTE(review): not read anywhere in this module; presumably consumed by
    # a permissions mixin to mark the endpoint as unauthenticated - confirm.
    public: bool
# Dict form of the ``methods`` argument: each CRUD method maps either to a
# MethodConfig dict (custom settings) or to a bool (True = enable with
# defaults, False = disable).
MethodConfigMapping = Mapping[CRUDMethod, MethodConfig | bool]
def resolve_schema(
    schema_candidate: type[Schema] | Schema | str | None,
    schema_import_path: str,
    default_schema: type[Schema] | None = None,
    *,
    context: str = "",
) -> type[Schema]:
    """Turn any supported schema reference into a Schema class.

    This is the one place in the blueprint where schema references are
    resolved, whatever form they were given in.

    Args:
        schema_candidate: The reference to resolve:
            - ``None``: fall back to ``default_schema``
            - a string: attribute name looked up in ``schema_import_path``
            - a Schema class: returned unchanged
            - a Schema instance: its class is returned
        schema_import_path: Dotted module path used for string lookups.
        default_schema: Returned as-is when ``schema_candidate`` is ``None``.
        context: Short label inserted into error messages
            (e.g. ``"PATCH arg_schema"``).

    Returns:
        The resolved Schema class.

    Raises:
        ValueError: No candidate and no default were given, or the module or
            attribute for a string candidate could not be found.
        TypeError: The imported object is not a Schema subclass, or the
            candidate is of an unsupported type.

    Example:
        >>> schema_cls = resolve_schema("UserSchema", "myapp.schemas")
        >>> schema_cls = resolve_schema(UserSchema, "", default_schema=BaseSchema)
    """
    # Every error message shares the same optional " for <context>" suffix.
    context_msg = f" for {context}" if context else ""

    if schema_candidate is None:
        if default_schema is not None:
            return default_schema
        raise ValueError(f"No schema provided{context_msg} and no default available")

    if isinstance(schema_candidate, str):
        try:
            imported = getattr(import_module(schema_import_path), schema_candidate)
        except ImportError as e:
            raise ValueError(f"Could not import module '{schema_import_path}' for schema '{schema_candidate}'.") from e
        except AttributeError as e:
            raise ValueError(f"Could not find schema '{schema_candidate}' in '{schema_import_path}'.") from e

        # A module attribute can be anything; insist on a real Schema class.
        if not (isinstance(imported, type) and issubclass(imported, Schema)):
            raise TypeError(f"Resolved schema{context_msg} must be a Schema subclass, got {type(imported).__name__}.")
        return imported

    if isinstance(schema_candidate, type) and issubclass(schema_candidate, Schema):
        return schema_candidate

    if isinstance(schema_candidate, Schema):
        return schema_candidate.__class__

    # Runtime safety net: the type hints rule this branch out, but callers
    # are not forced to honour them.
    raise TypeError(  # type: ignore[unreachable]
        f"Schema{context_msg} must be a string, Schema subclass, or Schema instance; "
        f"got {type(schema_candidate).__name__}."
    )
@dataclass
class CRUDConfig:
    """Fully resolved settings used to build one CRUD blueprint.

    Attributes:
        name: Blueprint name.
        url_prefix: URL prefix the blueprint is mounted under.
        import_name: Import name passed to the blueprint.
        model_cls: Resolved model class.
        model_name: ``__name__`` of ``model_cls``.
        schema_cls: Resolved default schema class.
        schema_name: ``__name__`` of ``schema_cls``.
        schema_import_path: Module path string-named schemas are imported from.
        model_import_path: Module path string-named models are imported from.
        res_id_name: Name of the model attribute holding the resource ID.
        res_id_param_name: Name of the ID parameter in the URL rule.
        methods: Enabled CRUD methods mapped to their configuration.
    """

    # Identity of the blueprint itself.
    name: str
    url_prefix: str
    import_name: str

    # Model side.
    model_cls: type[BaseModel]
    model_name: str

    # Schema side.
    schema_cls: type[Schema]
    schema_name: str

    # Where string references are looked up.
    schema_import_path: str
    model_import_path: str

    # Resource identifier: model attribute name and URL parameter name.
    res_id_name: str
    res_id_param_name: str

    # Only enabled methods appear as keys.
    methods: dict[CRUDMethod, MethodConfig]
class CRUDBlueprint(  # pyright: ignore[reportIncompatibleMethodOverride]
    CRUDPaginationMixin,
    BlueprintOperationIdMixin,
    Blueprint,
):
    """Blueprint subclass that automatically registers CRUD routes.

    This class extends Flask-Smorest Blueprint to provide automatic CRUD
    (Create, Read, Update, Delete) operations for SQLAlchemy models. It
    automatically generates RESTful endpoints based on the provided model
    and schema configuration.

    Args:
        name: Blueprint name (first positional arg)
        import_name: Import name (second positional arg)
        model: Model class or string name to use. When omitted, a name is
            derived from ``name`` and looked up in the model import path.
        schema: Schema class or string name to use. When omitted, the
            model's ``Schema`` attribute is used.
        model_import_name: Module path string-named models are imported from.
            Defaults to a sibling ``models`` module of ``import_name``.
        schema_import_name: Module path string-named schemas are imported
            from. Defaults to a sibling ``schemas`` module of ``import_name``.
        res_id: Name of the resource ID attribute on the model.
        res_id_param: Name of the ID parameter in the URL rule. Defaults to
            ``"<name>_id"`` (lowercased blueprint name).
        methods: Controls which CRUD methods to enable. Can be:
            - None (default): All methods are enabled
            - List of CRUDMethod: Only these methods are enabled
            - Dict mapping CRUDMethod to config: All methods enabled by default,
              unless explicitly set to False or present in skip_methods.
              Dict values can be:
                - True: Enable with defaults
                - False: Disable this method
                - MethodConfig dict: Enable with custom configuration
        skip_methods: List of CRUDMethod to explicitly disable. Applied after
            methods resolution. Useful when using dict-style methods to disable
            specific defaults.
        default_page_size: Page size for the INDEX endpoint; ``None``
            registers an INDEX endpoint without pagination.
        db_session: Session used for INDEX queries. Defaults to the
            package-level ``db.session``.
        static_folder, static_url_path, template_folder, url_prefix,
        subdomain, url_defaults, root_path, cli_group: Forwarded to the
            parent Blueprint constructor. ``url_prefix`` defaults to
            ``"/<name>/"``.

    Examples:
        Basic usage (all methods enabled):

        .. code-block:: python

            from myapp.models import Product

            blueprint = CRUDBlueprint(
                "products",
                __name__,
                model=Product,  # Use class (preferred)
                schema=Product.Schema,  # Auto-generated schema
            )

        Enable only specific methods:

        .. code-block:: python

            blueprint = CRUDBlueprint(
                "products",
                __name__,
                model=Product,
                schema=Product.Schema,
                methods=[CRUDMethod.INDEX, CRUDMethod.GET],  # Read-only
            )

        Disable specific methods:

        .. code-block:: python

            blueprint = CRUDBlueprint(
                "products",
                __name__,
                model=Product,
                schema=Product.Schema,
                skip_methods=[CRUDMethod.DELETE],  # All except delete
            )

        Advanced configuration (custom schemas, admin-only):

        .. code-block:: python

            blueprint = CRUDBlueprint(
                "products",
                __name__,
                model=Product,
                schema=Product.Schema,
                methods={
                    CRUDMethod.POST: {"schema": ProductCreateSchema},
                    CRUDMethod.DELETE: {"admin_only": True},
                    CRUDMethod.PATCH: False,  # Explicitly disable
                },
            )
    """

    _db_session: Session | scoped_session[Session]
    _config: CRUDConfig

    def __init__(
        self,
        name: str,
        import_name: str,
        model: type[BaseModel] | str | None = None,
        schema: type[Schema] | str | None = None,
        model_import_name: str | None = None,
        schema_import_name: str | None = None,
        res_id: str = "id",
        res_id_param: str | None = None,
        methods: list[CRUDMethod] | MethodConfigMapping | None = None,
        skip_methods: list[CRUDMethod] | None = None,
        default_page_size: int | None = 20,
        db_session: Session | scoped_session[Session] | None = None,
        static_folder: str | None = None,
        static_url_path: str | None = None,
        template_folder: str | None = None,
        url_prefix: str | None = None,
        subdomain: str | None = None,
        url_defaults: dict[str, Any] | None = None,
        root_path: str | None = None,
        cli_group: str | None = None,
    ) -> None:
        """Initialize CRUD blueprint with model and schema configuration.

        Resolves the configuration, initializes the parent Blueprint and then
        registers the CRUD routes. See the class docstring for the meaning
        of each argument.
        """
        if db_session is None:
            # Imported lazily so the default session is only looked up when
            # the caller did not supply one.
            from flask_more_smorest.sqla import db

            self._db_session = db.session
        else:
            self._db_session = db_session

        self._default_page_size = default_page_size

        # ``None`` stands for "all methods". A fresh list is built per call
        # instead of sharing one mutable default object across instances.
        resolved_methods: list[CRUDMethod] | MethodConfigMapping = list(CRUDMethod) if methods is None else methods

        self._config = self._build_config(
            name=name,
            import_name=import_name,
            model=model,
            schema=schema,
            model_import_name=model_import_name,
            schema_import_name=schema_import_name,
            res_id=res_id,
            res_id_param=res_id_param,
            methods=resolved_methods,
            skip_methods=skip_methods,
            url_prefix=url_prefix,
        )

        super().__init__(
            name,
            import_name,
            static_folder=static_folder,
            static_url_path=static_url_path,
            template_folder=template_folder,
            url_prefix=url_prefix or self._config.url_prefix,
            subdomain=subdomain,
            url_defaults=url_defaults,
            root_path=root_path,
            cli_group=cli_group,
        )

        update_schema = self._prepare_update_schema(self._config)
        self._register_crud_routes(self._config, update_schema)

    def _build_config(
        self,
        name: str,
        import_name: str,
        model: type[BaseModel] | str | None,
        schema: type[Schema] | str | None,
        model_import_name: str | None,
        schema_import_name: str | None,
        res_id: str,
        res_id_param: str | None,
        methods: list[CRUDMethod] | MethodConfigMapping,
        skip_methods: list[CRUDMethod] | None,
        url_prefix: str | None,
    ) -> CRUDConfig:
        """Build and validate configuration.

        Method resolution order:
        1. Normalize methods parameter (list → dict, or process dict with defaults)
        2. Apply skip_methods to remove explicitly disabled methods
        3. Return final enabled methods configuration

        Args:
            name: Blueprint name
            import_name: Import name for the blueprint
            model: Model class or string name
            schema: Schema class or string name
            model_import_name: Module path to import model from
            schema_import_name: Module path to import schema from
            res_id: Name of the resource ID field on the model
            res_id_param: Name for the ID parameter in URL routes
            methods: Methods configuration (list or dict)
            skip_methods: Methods to explicitly disable after normalization
            url_prefix: URL prefix for the blueprint

        Returns:
            Validated CRUDConfig object

        Raises:
            ValueError: If the model or schema cannot be resolved, or a
                method name is not a valid CRUDMethod
            TypeError: If the schema or methods arguments have an
                unsupported type
        """
        resolved_url_prefix: str = url_prefix or f"/{name}/"

        # Default import paths are sibling modules of the blueprint's module:
        # "pkg.views" -> "pkg.models" / "pkg.schemas".
        parent_package = import_name.split(".")[:-1]
        resolved_model_import_path: str = model_import_name or ".".join([*parent_package, "models"])
        resolved_schema_import_path: str = schema_import_name or ".".join([*parent_package, "schemas"])

        # Resolve model class
        model_cls = self._resolve_model_class(
            model or convert_snake_to_camel(name.capitalize()),
            resolved_model_import_path,
        )

        # Resolve schema class using unified resolver
        schema_cls = resolve_schema(
            schema,
            resolved_schema_import_path,
            default_schema=model_cls.Schema,
            context="blueprint schema",
        )

        res_id_param_name: str = res_id_param or f"{name.lower()}_id"

        # Step 1: Normalize methods into a dict with configs
        normalized_methods = self._normalize_methods(methods)

        # Step 2: Apply skip_methods to remove explicitly disabled methods.
        # This happens after normalization so it works consistently regardless
        # of whether methods was a list or dict.
        if skip_methods:
            import warnings

            for method_to_skip in skip_methods:
                method_enum = CRUDMethod(method_to_skip)
                normalized_methods.pop(method_enum, None)

                # Warn if the mapping already disabled this method (redundant).
                # Any Mapping is accepted for ``methods``, so check against
                # Mapping rather than dict.
                if isinstance(methods, Mapping) and methods.get(method_enum) is False:
                    warnings.warn(
                        f"Method {method_enum.value} is set to False in 'methods' dict "
                        f"and also appears in 'skip_methods'. The skip_methods entry is redundant.",
                        UserWarning,
                        # _build_config -> __init__ -> caller
                        stacklevel=3,
                    )

        return CRUDConfig(
            name=name,
            url_prefix=resolved_url_prefix,
            import_name=import_name,
            model_cls=model_cls,
            model_name=model_cls.__name__,
            schema_cls=schema_cls,
            schema_name=schema_cls.__name__,
            schema_import_path=resolved_schema_import_path,
            model_import_path=resolved_model_import_path,
            res_id_name=res_id,
            res_id_param_name=res_id_param_name,
            methods=normalized_methods,
        )

    def _normalize_methods(
        self,
        methods_raw: list[CRUDMethod] | MethodConfigMapping,
    ) -> dict[CRUDMethod, MethodConfig]:
        """Normalize different method inputs into a standard dict.

        Behavior:
        - If methods_raw is a list: Only those methods are enabled (explicit whitelist)
        - If methods_raw is a dict: All methods are enabled by default, unless:
            - Explicitly set to False in the dict
            - Will be removed later by skip_methods

        Args:
            methods_raw: Either a list of methods to enable, or a dict mapping
                methods to their configuration

        Returns:
            Normalized dict mapping enabled methods to their configuration

        Raises:
            TypeError: If methods_raw is not a list or dict, or if dict values are invalid
            ValueError: If a method name is not a valid CRUDMethod
        """
        normalized: dict[CRUDMethod, MethodConfig] = {}

        if isinstance(methods_raw, list):
            # List mode: explicit whitelist - only these methods are enabled
            for item in methods_raw:
                key = CRUDMethod(item)
                normalized[key] = {}
            return normalized

        if not isinstance(methods_raw, Mapping):
            raise TypeError(
                f"CRUDBlueprint 'methods' argument must be a list or a dict, got {type(methods_raw).__name__}"
            )

        # Dict mode: all methods enabled by default, process overrides.
        # First, enable all methods with default config
        for method in CRUDMethod:
            normalized[method] = {}

        # Then apply dict overrides
        for method, config in methods_raw.items():
            key = CRUDMethod(method)

            if config is True:
                # Explicitly enabled with defaults (redundant but allowed)
                normalized[key] = {}
            elif config is False:
                # Explicitly disabled - remove from normalized dict
                normalized.pop(key, None)
            elif isinstance(config, dict):
                # Custom configuration provided
                normalized[key] = config
            else:
                raise TypeError(
                    f"CRUDBlueprint method config for {method} must be a dict, True, or False; "
                    f"got {type(config).__name__}"
                )

        return normalized

    @staticmethod
    def _resolve_model_class(
        model_candidate: type[BaseModel] | str,
        model_import_path: str,
    ) -> type[BaseModel]:
        """Resolve a model reference (class or string name) to a BaseModel subclass.

        Args:
            model_candidate: Model class or string name to resolve
            model_import_path: Module path to import string models from

        Returns:
            Resolved BaseModel subclass

        Raises:
            ValueError: If model cannot be imported or resolved
        """
        if isinstance(model_candidate, str):
            try:
                model_module = import_module(model_import_path)
                model_cls = getattr(model_module, model_candidate)
            except ImportError as e:
                raise ValueError(f"Could not import module '{model_import_path}' for model '{model_candidate}'.") from e
            except AttributeError as e:
                raise ValueError(f"Could not find model '{model_candidate}' in '{model_import_path}'.") from e

            # Validate imported model is actually a BaseModel subclass
            if not isinstance(model_cls, type) or not issubclass(model_cls, BaseModel):
                raise ValueError(
                    f"Imported '{model_candidate}' from '{model_import_path}' "
                    f"is not a BaseModel subclass, got {type(model_cls).__name__}."
                )
            return model_cls

        if isinstance(model_candidate, type) and issubclass(model_candidate, BaseModel):
            return model_candidate

        raise ValueError(
            f"CRUDBlueprint 'model' argument must be a string or a BaseModel subclass, "
            f"got {type(model_candidate).__name__}."
        )

    def _resolve_schema_class(
        self,
        schema_candidate: type[Schema] | Schema | str,
        *,
        config: CRUDConfig,
        method: CRUDMethod,
    ) -> type[Schema]:
        """Resolve a schema reference for a specific CRUD method.

        This is a convenience wrapper around the module-level resolve_schema
        function, providing method-specific context for error messages.

        Args:
            schema_candidate: Schema to resolve
            config: CRUD configuration with import paths
            method: The CRUD method this schema is for

        Returns:
            Resolved Schema class
        """
        return resolve_schema(
            schema_candidate,
            config.schema_import_path,
            context=f"{method.value} method",
        )

    def _prepare_update_schema(
        self, config: CRUDConfig
    ) -> Schema | type[Schema] | SQLAlchemySchema | type[SQLAlchemySchema]:
        """Create update schema for PATCH operations.

        If an explicit arg_schema is provided in PATCH method config, it's used.
        Otherwise, creates a partial version of the default schema.

        Args:
            config: Configuration object

        Returns:
            Update schema instance or class
        """
        patch_config = config.methods.get(CRUDMethod.PATCH, {})
        update_schema_arg = patch_config.get("arg_schema")

        if update_schema_arg is not None:
            # Explicit patch schema provided - use unified resolver
            return resolve_schema(
                update_schema_arg,
                config.schema_import_path,
                context="PATCH arg_schema",
            )

        # Create partial schema from default
        # NOTE: the following will trigger a warning in apispec if no custom resolver is set
        update_schema = config.schema_cls(partial=True)
        if isinstance(update_schema, SQLAlchemySchema):
            # The PATCH handler receives the payload as a dict and applies it
            # with ``res.update(**payload)``, so the schema must not build a
            # model instance on load.
            update_schema._load_instance = False

        return update_schema

    @staticmethod
    def _url_converter_for(sql_type_name: str) -> str:
        """Map a lowercased SQL type name to a URL converter name.

        Args:
            sql_type_name: Lowercased string form of the ID column's type
                (e.g. ``"uuid"``, ``"char(32)"``, ``"integer"``, ``"varchar(50)"``)

        Returns:
            ``"uuid"`` for char-based types, ``"int"`` for integer types,
            ``"string"`` for text types; any other name is returned
            unchanged so custom converters registered under that name keep
            working.
        """
        if sql_type_name.startswith("char"):
            return "uuid"

        # Strip length/precision arguments: "varchar(50)" -> "varchar".
        base_name = sql_type_name.split("(", 1)[0].strip()
        if base_name in {"integer", "int", "bigint", "smallint"}:
            return "int"
        if base_name in {"varchar", "nvarchar", "text", "string"}:
            return "string"
        return sql_type_name

    def _register_crud_routes(
        self,
        config: CRUDConfig,
        update_schema: Schema | type[Schema],
    ) -> None:
        """Register all CRUD routes for the blueprint.

        Two MethodView classes are built on the fly: a collection view
        (INDEX / POST) mounted at ``""`` and a resource view
        (GET / PATCH / DELETE) mounted at ``<converter:res_id_param_name>``.
        Each is only registered when at least one of its methods is enabled.

        Args:
            config: Configuration object
            update_schema: Update schema for PATCH operations
        """
        # The URL converter is derived from the ID column's SQL type name.
        id_type = self._url_converter_for(str(getattr(config.model_cls, config.res_id_name).type).lower())

        model_cls: type[BaseModel] = config.model_cls
        schema_cls: type[Schema] = config.schema_cls

        if CRUDMethod.INDEX in config.methods or CRUDMethod.POST in config.methods:
            # Initialize variables to avoid "possibly unbound" errors
            index_schema_class: type[Schema] | None = None
            query_filter_schema: type[Schema] | None = None

            if CRUDMethod.INDEX in config.methods:
                index_schema_candidate = config.methods[CRUDMethod.INDEX].get("schema", schema_cls)
                index_schema_class = self._resolve_schema_class(
                    index_schema_candidate, config=config, method=CRUDMethod.INDEX
                )
                query_filter_schema = generate_filter_schema(base_schema=index_schema_class)

            class GenericIndex(MethodView):
                """Index/Post endpoints."""

                if CRUDMethod.INDEX in config.methods:
                    if self._default_page_size is not None:

                        @self.arguments(query_filter_schema, location="query", unknown=RAISE)  # pyright: ignore[reportArgumentType]
                        @self.response(HTTPStatus.OK, index_schema_class(many=True))  # type: ignore[misc]  # pyright: ignore[reportArgumentType, reportOptionalCall]
                        @self.paginate(page_size=self._default_page_size)
                        @self.doc(operationId=f"list{config.model_name}")
                        def get(
                            _self,  # NOTE: using _self to avoid collision with outer self
                            filters: dict,
                            pagination_parameters: "PaginationParameters",
                            **kwargs: Any,
                        ) -> Sequence[BaseModel]:
                            """Fetch all resources.

                            kwargs might contains path parameters to filter by (eg /user/<uuid:user_id>/roles/)
                            """
                            stmts = get_statements_from_filters(filters, model=model_cls)
                            base_query = sa.select(model_cls).filter_by(**kwargs).filter(*stmts)

                            # Count the unpaginated result set so the total
                            # can be reported through pagination_parameters.
                            count_query = sa.select(sa.func.count()).select_from(base_query.subquery())
                            total_items = self._db_session.scalar(count_query)
                            pagination_parameters.item_count = total_items  # pyright: ignore[reportAttributeAccessIssue]

                            # A non-positive page size returns everything.
                            if pagination_parameters.page_size > 0:
                                paginated_query = base_query.limit(pagination_parameters.page_size).offset(
                                    pagination_parameters.page_size * (pagination_parameters.page - 1)
                                )
                            else:
                                paginated_query = base_query

                            return self._db_session.execute(paginated_query).scalars().all()

                    else:

                        @self.arguments(query_filter_schema, location="query", unknown=RAISE)  # pyright: ignore[reportArgumentType]
                        @self.response(HTTPStatus.OK, index_schema_class(many=True))  # type: ignore[misc]  # pyright: ignore[reportArgumentType, reportOptionalCall]
                        @self.doc(operationId=f"list{config.model_name}")
                        def get(
                            _self,
                            filters: dict,
                            **kwargs: Any,
                        ) -> Sequence[BaseModel]:
                            """Fetch all resources (no pagination)."""
                            stmts = get_statements_from_filters(filters, model=model_cls)
                            base_query = sa.select(model_cls).filter_by(**kwargs).filter(*stmts)
                            return self._db_session.execute(base_query).scalars().all()

                if CRUDMethod.POST in config.methods:

                    @self.arguments(config.methods[CRUDMethod.POST].get("schema", schema_cls))
                    @self.response(
                        HTTPStatus.OK,
                        config.methods[CRUDMethod.POST].get("schema", schema_cls),
                    )
                    @self.doc(
                        responses={
                            HTTPStatus.NOT_FOUND: {"description": f"{config.name} resource not found"},
                            HTTPStatus.CONFLICT: {"description": "DB error."},
                        },
                        operationId=f"create{config.model_name}",
                    )
                    def post(
                        _self,
                        new_object: BaseModel,
                        **kwargs: str | int | float | bool | bytes | None,
                    ) -> BaseModel:
                        """Create and return new resource."""
                        new_object.update(commit=True, **kwargs)
                        return new_object

            self._configure_endpoint(
                GenericIndex,
                "get",
                f"Fetch all {config.name} resources.",
                config.methods.get(CRUDMethod.INDEX, {}),
            )
            self._configure_endpoint(
                GenericIndex,
                "post",
                f"Create and return new {config.name}.",
                config.methods.get(CRUDMethod.POST, {}),
            )

            self.route("")(GenericIndex)

        class GenericCRUD(MethodView):
            """Resource-specific endpoints."""

            if CRUDMethod.GET in config.methods:

                @self.doc(
                    responses={HTTPStatus.NOT_FOUND: {"description": f"{config.name} not found"}},
                    operationId=f"get{config.model_name}",
                )
                @self.response(
                    HTTPStatus.OK,
                    config.methods[CRUDMethod.GET].get("schema", schema_cls),
                )
                def get(_self, **kwargs: Any) -> BaseModel:
                    """Fetch resource by ID."""
                    # Translate the URL parameter name into the model's ID attribute name.
                    kwargs[config.res_id_name] = kwargs.pop(config.res_id_param_name)
                    return model_cls.get_by_or_404(**kwargs)

            if CRUDMethod.PATCH in config.methods:

                @self.arguments(update_schema)
                @self.doc(
                    responses={
                        HTTPStatus.NOT_FOUND: {"description": f"{config.name} not found"},
                        HTTPStatus.CONFLICT: {"description": "DB error."},
                    },
                    operationId=f"update{config.model_name}",
                )
                @self.response(
                    HTTPStatus.OK,
                    config.methods[CRUDMethod.PATCH].get("schema", schema_cls),
                )
                def patch(_self, payload: dict, **kwargs: str | int | uuid.UUID | bool | None) -> BaseModel:
                    """Update resource."""
                    kwargs[config.res_id_name] = kwargs.pop(config.res_id_param_name)
                    res = model_cls.get_by_or_404(**kwargs)
                    res.update(**payload)
                    return res

            if CRUDMethod.DELETE in config.methods:

                @self.response(HTTPStatus.NO_CONTENT, description=f"{config.name} deleted")
                @self.doc(operationId=f"delete{config.model_name}")
                def delete(_self, **kwargs: str | int | uuid.UUID | bool | None) -> tuple[str, int]:
                    """Delete resource."""
                    kwargs[config.res_id_name] = kwargs.pop(config.res_id_param_name)
                    res = model_cls.get_by_or_404(**kwargs)
                    res.delete()
                    return "", HTTPStatus.NO_CONTENT

            # NOTE(review): configs built by _normalize_methods only carry
            # CRUDMethod keys, and CRUDMethod("PUT") raises ValueError there,
            # so this guard looks unreachable on that path. It is kept for
            # configs constructed some other way.
            if "PUT" in config.methods:
                raise NotImplementedError("PUT method is not implemented. Use PATCH instead.")

        self._configure_endpoint(
            GenericCRUD,
            "get",
            f"Fetch {config.name} by ID.",
            config.methods.get(CRUDMethod.GET, {}),
        )
        self._configure_endpoint(
            GenericCRUD,
            "patch",
            f"Update {config.name} by ID.",
            config.methods.get(CRUDMethod.PATCH, {}),
        )
        self._configure_endpoint(
            GenericCRUD,
            "delete",
            f"Delete {config.name} by ID.",
            config.methods.get(CRUDMethod.DELETE, {}),
        )

        # Only register GenericCRUD if it has at least one method
        if any(method in config.methods for method in [CRUDMethod.GET, CRUDMethod.PATCH, CRUDMethod.DELETE]):
            self.route(f"<{id_type}:{config.res_id_param_name}>")(GenericCRUD)

    def _configure_endpoint(
        self,
        view_cls: type[MethodView],
        method_name: str,
        docstring: str,
        method_config: MethodConfig,
    ) -> None:
        """Configure endpoint with docstring and admin decorator if needed.

        Does nothing when ``view_cls`` has no attribute named ``method_name``
        (i.e. the method was not enabled).

        Args:
            view_cls: MethodView class containing the endpoint
            method_name: Name of the method to configure
            docstring: Docstring to set on the method
            method_config: Configuration dict for the method

        Raises:
            TypeError: If ``admin_only`` is requested but this blueprint does
                not inherit from PermsBlueprintMixin
        """
        if hasattr(view_cls, method_name):
            method = getattr(view_cls, method_name)
            method.__doc__ = docstring

            if method_config.get("admin_only", False):
                # Imported lazily; only needed when an admin-only endpoint is requested.
                from ..perms import PermsBlueprintMixin

                if isinstance(self, PermsBlueprintMixin):
                    self.admin_endpoint(method)
                else:
                    raise TypeError("Blueprint must inherit from PermsBlueprintMixin to set admin_only endpoint.")