diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index d4860d5..332f3ca 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,2 +1 @@ -# These are supported funding model platforms -custom: https://www.buymeacoffee.com/frankie567 +github: frankie567 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 508f620..d819c8f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,34 +3,74 @@ name: Build on: [push, pull_request] jobs: - test: runs-on: ubuntu-latest + + services: + postgres: + image: postgres:alpine + ports: + - 5432:5432 + env: + POSTGRES_USER: fastapiusers + POSTGRES_PASSWORD: fastapiuserspassword + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + mariadb: + image: mariadb + ports: + - 3306:3306 + env: + MARIADB_ROOT_PASSWORD: fastapiuserspassword + MARIADB_DATABASE: fastapiusers + MARIADB_USER: fastapiusers + MARIADB_PASSWORD: fastapiuserspassword + strategy: + fail-fast: false matrix: - python_version: [3.7, 3.8, 3.9, '3.10'] + python_version: [3.9, '3.10', '3.11', '3.12', '3.13'] + database_url: + [ + "sqlite+aiosqlite:///./test-fastapiusers.db", + "postgresql+asyncpg://fastapiusers:fastapiuserspassword@localhost:5432/fastapiusers", + "mysql+aiomysql://root:fastapiuserspassword@localhost:3306/fastapiusers", + ] steps: - - uses: actions/checkout@v1 - - name: Set up Python - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python_version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.dev.txt - - name: Test with pytest - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - run: | - pytest --cov=fastapi_users_db_sqlalchemy/ - codecov - - name: Build and install it on system host - run: | - flit build - flit install --python $(which python) - python test_build.py + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python_version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch + hatch env create + - name: Lint and typecheck + run: | + hatch run lint-check + - name: Test + env: + DATABASE_URL: ${{ matrix.database_url }} + run: | + hatch run test-cov-xml + - uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: true + verbose: true + - name: Build and install it on system host + run: | + hatch build + pip install dist/fastapi_users_db_sqlalchemy-*.whl + python test_build.py release: runs-on: ubuntu-latest @@ -38,18 +78,26 @@ jobs: if: startsWith(github.ref, 'refs/tags/') steps: - - uses: actions/checkout@v1 - - name: Set up Python - uses: actions/setup-python@v1 - with: - python-version: 3.7 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.dev.txt - - name: Release on PyPI - env: - FLIT_USERNAME: ${{ secrets.FLIT_USERNAME }} - FLIT_PASSWORD: ${{ secrets.FLIT_PASSWORD }} - run: | - flit publish + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.9 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch + - name: Build and publish on PyPI + env: + HATCH_INDEX_USER: ${{ secrets.HATCH_INDEX_USER }} + HATCH_INDEX_AUTH: ${{ secrets.HATCH_INDEX_AUTH }} + run: | + hatch build + hatch publish + - name: Create release + uses: ncipollo/release-action@v1 + with: + draft: true + body: ${{ github.event.head_commit.message }} + artifacts: dist/*.whl,dist/*.tar.gz + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index b949f48..434348e 100644 --- a/.gitignore +++ b/.gitignore @@ -104,9 +104,6 @@ ENV/ # mypy .mypy_cache/ -# .vscode -.vscode/ - # OS files .DS_Store diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..5d8d955 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,21 @@ +{ + "python.analysis.typeCheckingMode": "basic", + "python.analysis.autoImportCompletions": true, + "python.terminal.activateEnvironment": true, + "python.terminal.activateEnvInCurrentTerminal": true, + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + "editor.rulers": [88], + "python.defaultInterpreterPath": "${workspaceFolder}/.hatch/fastapi-users-db-sqlalchemy/bin/python", + "python.testing.pytestPath": "${workspaceFolder}/.hatch/fastapi-users-db-sqlalchemy/bin/pytest", + "python.testing.cwd": "${workspaceFolder}", + "python.testing.pytestArgs": ["--no-cov"], + "[python]": { + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + }, + "editor.defaultFormatter": "charliermarsh.ruff" + } + } diff --git a/README.md b/README.md index 2a9f4e1..2f8498c 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ [](https://badge.fury.io/py/fastapi-users-db-sqlalchemy) [](https://pepy.tech/project/fastapi-users-db-sqlalchemy)
--- @@ -32,40 +32,14 @@ Add quickly a registration and authentication system to your [FastAPI](https://f ### Setup environment -You should create a virtual environment and activate it: - -```bash -python -m venv venv/ -``` - -```bash -source venv/bin/activate -``` - -And then install the development dependencies: - -```bash -pip install -r requirements.dev.txt -``` +We use [Hatch](https://hatch.pypa.io/latest/install/) to manage the development environment and production build. Ensure it's installed on your system. ### Run unit tests You can run all the tests with: ```bash -make test -``` - -Alternatively, you can run `pytest` yourself: - -```bash -pytest -``` - -There are quite a few unit tests, so you might run into ulimit issues where there are too many open file descriptors. You may be able to set a new, higher limit temporarily with: - -```bash -ulimit -n 2048 +hatch run test ``` ### Format the code @@ -73,7 +47,7 @@ ulimit -n 2048 Execute the following command to apply `isort` and `black` formatting: ```bash -make format +hatch run lint ``` ## License diff --git a/fastapi_users_db_sqlalchemy/__init__.py b/fastapi_users_db_sqlalchemy/__init__.py index 6889aab..467a2bf 100644 --- a/fastapi_users_db_sqlalchemy/__init__.py +++ b/fastapi_users_db_sqlalchemy/__init__.py @@ -1,152 +1,190 @@ """FastAPI Users database adapter for SQLAlchemy.""" -from typing import Optional, Type + +import uuid +from typing import TYPE_CHECKING, Any, Generic, Optional from fastapi_users.db.base import BaseUserDatabase -from fastapi_users.models import UD -from pydantic import UUID4 -from sqlalchemy import ( - Boolean, - Column, - ForeignKey, - Integer, - String, - delete, - func, - select, - update, -) +from fastapi_users.models import ID, OAP, UP +from sqlalchemy import Boolean, ForeignKey, Integer, String, func, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import Mapped, declared_attr, mapped_column from sqlalchemy.sql import Select -from fastapi_users_db_sqlalchemy.guid import GUID +from fastapi_users_db_sqlalchemy.generics import GUID + +__version__ = "7.0.0" -__version__ = "3.0.1" +UUID_ID = uuid.UUID -class SQLAlchemyBaseUserTable: +class SQLAlchemyBaseUserTable(Generic[ID]): """Base SQLAlchemy users table definition.""" __tablename__ = "user" - id = Column(GUID, primary_key=True) - email = Column(String(length=320), unique=True, index=True, nullable=False) - hashed_password = Column(String(length=1024), nullable=False) - is_active = Column(Boolean, default=True, nullable=False) - is_superuser = Column(Boolean, default=False, nullable=False) - is_verified = Column(Boolean, default=False, nullable=False) + if TYPE_CHECKING: # pragma: no cover + id: ID + email: str + hashed_password: str + is_active: bool + is_superuser: bool + is_verified: bool + else: + email: Mapped[str] = mapped_column( + String(length=320), unique=True, index=True, nullable=False + ) + hashed_password: Mapped[str] = mapped_column( + String(length=1024), nullable=False + ) + is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + is_superuser: Mapped[bool] = mapped_column( + Boolean, default=False, nullable=False + ) + is_verified: Mapped[bool] = mapped_column( + Boolean, default=False, nullable=False + ) + + +class SQLAlchemyBaseUserTableUUID(SQLAlchemyBaseUserTable[UUID_ID]): + if TYPE_CHECKING: # pragma: no cover + id: UUID_ID + else: + id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4) -class SQLAlchemyBaseOAuthAccountTable: +class SQLAlchemyBaseOAuthAccountTable(Generic[ID]): """Base SQLAlchemy OAuth account table definition.""" __tablename__ = "oauth_account" - id = Column(GUID, primary_key=True) - oauth_name = Column(String(length=100), index=True, nullable=False) - access_token = Column(String(length=1024), nullable=False) - expires_at = Column(Integer, nullable=True) - refresh_token = Column(String(length=1024), nullable=True) - account_id = Column(String(length=320), index=True, nullable=False) - account_email = Column(String(length=320), nullable=False) + if TYPE_CHECKING: # pragma: no cover + id: ID + oauth_name: str + access_token: str + expires_at: Optional[int] + refresh_token: Optional[str] + account_id: str + account_email: str + else: + oauth_name: Mapped[str] = mapped_column( + String(length=100), index=True, nullable=False + ) + access_token: Mapped[str] = mapped_column(String(length=1024), nullable=False) + expires_at: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + refresh_token: Mapped[Optional[str]] = mapped_column( + String(length=1024), nullable=True + ) + account_id: Mapped[str] = mapped_column( + String(length=320), index=True, nullable=False + ) + account_email: Mapped[str] = mapped_column(String(length=320), nullable=False) + - @declared_attr - def user_id(cls): - return Column(GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False) +class SQLAlchemyBaseOAuthAccountTableUUID(SQLAlchemyBaseOAuthAccountTable[UUID_ID]): + if TYPE_CHECKING: # pragma: no cover + id: UUID_ID + user_id: UUID_ID + else: + id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4) + + @declared_attr + def user_id(cls) -> Mapped[GUID]: + return mapped_column( + GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False + ) -class SQLAlchemyUserDatabase(BaseUserDatabase[UD]): +class SQLAlchemyUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]): """ Database adapter for SQLAlchemy. - :param user_db_model: Pydantic model of a DB representation of a user. :param session: SQLAlchemy session instance. :param user_table: SQLAlchemy user model. :param oauth_account_table: Optional SQLAlchemy OAuth accounts model. """ session: AsyncSession - user_table: Type[SQLAlchemyBaseUserTable] - oauth_account_table: Optional[Type[SQLAlchemyBaseOAuthAccountTable]] + user_table: type[UP] + oauth_account_table: Optional[type[SQLAlchemyBaseOAuthAccountTable]] def __init__( self, - user_db_model: Type[UD], session: AsyncSession, - user_table: Type[SQLAlchemyBaseUserTable], - oauth_account_table: Optional[Type[SQLAlchemyBaseOAuthAccountTable]] = None, + user_table: type[UP], + oauth_account_table: Optional[type[SQLAlchemyBaseOAuthAccountTable]] = None, ): - super().__init__(user_db_model) self.session = session self.user_table = user_table self.oauth_account_table = oauth_account_table - async def get(self, id: UUID4) -> Optional[UD]: + async def get(self, id: ID) -> Optional[UP]: statement = select(self.user_table).where(self.user_table.id == id) return await self._get_user(statement) - async def get_by_email(self, email: str) -> Optional[UD]: + async def get_by_email(self, email: str) -> Optional[UP]: statement = select(self.user_table).where( func.lower(self.user_table.email) == func.lower(email) ) return await self._get_user(statement) - async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: - if self.oauth_account_table is not None: - statement = ( - select(self.user_table) - .join(self.oauth_account_table) - .where(self.oauth_account_table.oauth_name == oauth) - .where(self.oauth_account_table.account_id == account_id) - ) - return await self._get_user(statement) + async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]: + if self.oauth_account_table is None: + raise NotImplementedError() + + statement = ( + select(self.user_table) + .join(self.oauth_account_table) + .where(self.oauth_account_table.oauth_name == oauth) # type: ignore + .where(self.oauth_account_table.account_id == account_id) # type: ignore + ) + return await self._get_user(statement) - async def create(self, user: UD) -> UD: - user_table = self.user_table(**user.dict(exclude={"oauth_accounts"})) - self.session.add(user_table) + async def create(self, create_dict: dict[str, Any]) -> UP: + user = self.user_table(**create_dict) + self.session.add(user) + await self.session.commit() + await self.session.refresh(user) + return user - if self.oauth_account_table is not None: - for oauth_account in user.oauth_accounts: - oauth_account_table = self.oauth_account_table( - **oauth_account.dict(), user_id=user.id - ) - self.session.add(oauth_account_table) + async def update(self, user: UP, update_dict: dict[str, Any]) -> UP: + for key, value in update_dict.items(): + setattr(user, key, value) + self.session.add(user) + await self.session.commit() + await self.session.refresh(user) + return user + async def delete(self, user: UP) -> None: + await self.session.delete(user) await self.session.commit() - return await self.get(user.id) - - async def update(self, user: UD) -> UD: - user_table = await self.session.get(self.user_table, user.id) - for key, value in user.dict(exclude={"oauth_accounts"}).items(): - setattr(user_table, key, value) - self.session.add(user_table) - - if self.oauth_account_table is not None: - for oauth_account in user.oauth_accounts: - statement = update( - self.oauth_account_table, - whereclause=self.oauth_account_table.id == oauth_account.id, - values={**oauth_account.dict(), "user_id": user.id}, - ) - await self.session.execute(statement) + + async def add_oauth_account(self, user: UP, create_dict: dict[str, Any]) -> UP: + if self.oauth_account_table is None: + raise NotImplementedError() + + await self.session.refresh(user) + oauth_account = self.oauth_account_table(**create_dict) + self.session.add(oauth_account) + user.oauth_accounts.append(oauth_account) # type: ignore + self.session.add(user) await self.session.commit() - return await self.get(user.id) + return user + + async def update_oauth_account( + self, user: UP, oauth_account: OAP, update_dict: dict[str, Any] + ) -> UP: + if self.oauth_account_table is None: + raise NotImplementedError() - async def delete(self, user: UD) -> None: - statement = delete(self.user_table, self.user_table.id == user.id) - await self.session.execute(statement) + for key, value in update_dict.items(): + setattr(oauth_account, key, value) + self.session.add(oauth_account) await self.session.commit() - async def _get_user(self, statement: Select) -> Optional[UD]: - if self.oauth_account_table is not None: - statement = statement.options(joinedload("oauth_accounts")) + return user + async def _get_user(self, statement: Select) -> Optional[UP]: results = await self.session.execute(statement) - user = results.first() - if user is None: - return None - - return self.user_db_model.from_orm(user[0]) + return results.unique().scalar_one_or_none() diff --git a/fastapi_users_db_sqlalchemy/access_token.py b/fastapi_users_db_sqlalchemy/access_token.py index 7c3de7e..9f68af6 100644 --- a/fastapi_users_db_sqlalchemy/access_token.py +++ b/fastapi_users_db_sqlalchemy/access_token.py @@ -1,80 +1,89 @@ +import uuid from datetime import datetime -from typing import Generic, Optional, Type +from typing import TYPE_CHECKING, Any, Generic, Optional -from fastapi_users.authentication.strategy.db import A, AccessTokenDatabase -from sqlalchemy import Column, DateTime, ForeignKey, String, delete, select, update +from fastapi_users.authentication.strategy.db import AP, AccessTokenDatabase +from fastapi_users.models import ID +from sqlalchemy import ForeignKey, String, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.orm import Mapped, declared_attr, mapped_column -from fastapi_users_db_sqlalchemy.guid import GUID +from fastapi_users_db_sqlalchemy.generics import GUID, TIMESTAMPAware, now_utc -class SQLAlchemyBaseAccessTokenTable: +class SQLAlchemyBaseAccessTokenTable(Generic[ID]): """Base SQLAlchemy access token table definition.""" __tablename__ = "accesstoken" - token = Column(String(length=43), primary_key=True) - created_at = Column(DateTime(timezone=True), index=True, nullable=False) + if TYPE_CHECKING: # pragma: no cover + token: str + created_at: datetime + user_id: ID + else: + token: Mapped[str] = mapped_column(String(length=43), primary_key=True) + created_at: Mapped[datetime] = mapped_column( + TIMESTAMPAware(timezone=True), index=True, nullable=False, default=now_utc + ) + + +class SQLAlchemyBaseAccessTokenTableUUID(SQLAlchemyBaseAccessTokenTable[uuid.UUID]): + if TYPE_CHECKING: # pragma: no cover + user_id: uuid.UUID + else: - @declared_attr - def user_id(cls): - return Column(GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False) + @declared_attr + def user_id(cls) -> Mapped[GUID]: + return mapped_column( + GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False + ) -class SQLAlchemyAccessTokenDatabase(AccessTokenDatabase, Generic[A]): +class SQLAlchemyAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]): """ Access token database adapter for SQLAlchemy. - :param access_token_model: Pydantic model of a DB representation of an access token. :param session: SQLAlchemy session instance. :param access_token_table: SQLAlchemy access token model. """ def __init__( self, - access_token_model: Type[A], session: AsyncSession, - access_token_table: Type[SQLAlchemyBaseAccessTokenTable], + access_token_table: type[AP], ): - self.access_token_model = access_token_model self.session = session self.access_token_table = access_token_table async def get_by_token( self, token: str, max_age: Optional[datetime] = None - ) -> Optional[A]: + ) -> Optional[AP]: statement = select(self.access_token_table).where( - self.access_token_table.token == token + self.access_token_table.token == token # type: ignore ) if max_age is not None: - statement = statement.where(self.access_token_table.created_at >= max_age) + statement = statement.where( + self.access_token_table.created_at >= max_age # type: ignore + ) results = await self.session.execute(statement) - access_token = results.first() - if access_token is None: - return None - return self.access_token_model.from_orm(access_token[0]) - - async def create(self, access_token: A) -> A: - access_token_db = self.access_token_table(**access_token.dict()) - self.session.add(access_token_db) + return results.scalar_one_or_none() + + async def create(self, create_dict: dict[str, Any]) -> AP: + access_token = self.access_token_table(**create_dict) + self.session.add(access_token) await self.session.commit() + await self.session.refresh(access_token) return access_token - async def update(self, access_token: A) -> A: - statement = ( - update(self.access_token_table) - .where(self.access_token_table.token == access_token.token) - .values(access_token.dict()) - ) - await self.session.execute(statement) + async def update(self, access_token: AP, update_dict: dict[str, Any]) -> AP: + for key, value in update_dict.items(): + setattr(access_token, key, value) + self.session.add(access_token) await self.session.commit() + await self.session.refresh(access_token) return access_token - async def delete(self, access_token: A) -> None: - statement = delete( - self.access_token_table, self.access_token_table.token == access_token.token - ) - await self.session.execute(statement) + async def delete(self, access_token: AP) -> None: + await self.session.delete(access_token) await self.session.commit() diff --git a/fastapi_users_db_sqlalchemy/guid.py b/fastapi_users_db_sqlalchemy/generics.py similarity index 57% rename from fastapi_users_db_sqlalchemy/guid.py rename to fastapi_users_db_sqlalchemy/generics.py index 6e453ae..ddfe639 100644 --- a/fastapi_users_db_sqlalchemy/guid.py +++ b/fastapi_users_db_sqlalchemy/generics.py @@ -1,19 +1,22 @@ import uuid +from datetime import datetime, timezone +from typing import Optional from pydantic import UUID4 +from sqlalchemy import CHAR, TIMESTAMP, TypeDecorator from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.types import CHAR, TypeDecorator class GUID(TypeDecorator): # pragma: no cover - """Platform-independent GUID type. + """ + Platform-independent GUID type. Uses PostgreSQL's UUID type, otherwise uses CHAR(36), storing as regular strings. """ class UUIDChar(CHAR): - python_type = UUID4 + python_type = UUID4 # type: ignore impl = UUIDChar cache_ok = True @@ -42,3 +45,24 @@ def process_result_value(self, value, dialect): if not isinstance(value, uuid.UUID): value = uuid.UUID(value) return value + + +def now_utc(): + return datetime.now(timezone.utc) + + +class TIMESTAMPAware(TypeDecorator): # pragma: no cover + """ + MySQL and SQLite will always return naive-Python datetimes. + + We store everything as UTC, but we want to have + only offset-aware Python datetimes, even with MySQL and SQLite. + """ + + impl = TIMESTAMP + cache_ok = True + + def process_result_value(self, value: Optional[datetime], dialect): + if value is not None and dialect.name != "postgresql": + return value.replace(tzinfo=timezone.utc) + return value diff --git a/pyproject.toml b/pyproject.toml index 0a614e0..ef32d8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,38 +1,94 @@ -[tool.isort] -profile = "black" +[tool.mypy] +plugins = "sqlalchemy.ext.mypy.plugin" [tool.pytest.ini_options] -asyncio_mode = "auto" +asyncio_mode = "strict" +asyncio_default_fixture_loop_scope = "function" addopts = "--ignore=test_build.py" -markers = ["db"] + +[tool.ruff] + +[tool.ruff.lint] +extend-select = ["I", "UP"] + +[tool.hatch] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.version] +source = "regex_commit" +commit_extra_args = ["-e"] +path = "fastapi_users_db_sqlalchemy/__init__.py" + +[tool.hatch.envs.default] +installer = "uv" +dependencies = [ + "aiosqlite", + "asyncpg", + "aiomysql", + "pytest", + "pytest-asyncio", + "black", + "mypy", + "pytest-cov", + "pytest-mock", + "asynctest", + "httpx", + "asgi_lifespan", + "ruff", + "sqlalchemy[asyncio,mypy]", +] + +[tool.hatch.envs.default.scripts] +test = "pytest --cov=fastapi_users_db_sqlalchemy/ --cov-report=term-missing --cov-fail-under=100" +test-cov-xml = "pytest --cov=fastapi_users_db_sqlalchemy/ --cov-report=xml --cov-fail-under=100" +lint = [ + "ruff format . ", + "ruff check --fix .", + "mypy fastapi_users_db_sqlalchemy/", +] +lint-check = [ + "ruff format --check .", + "ruff check .", + "mypy fastapi_users_db_sqlalchemy/", +] + +[tool.hatch.build.targets.sdist] +support-legacy = true # Create setup.py [build-system] -requires = ["flit_core >=2,<3"] -build-backend = "flit_core.buildapi" - -[tool.flit.metadata] -module = "fastapi_users_db_sqlalchemy" -dist-name = "fastapi-users-db-sqlalchemy" -author = "François Voron" -author-email = "fvoron@gmail.com" -home-page = "https://github.com/fastapi-users/fastapi-users-db-sqlalchemy" +requires = ["hatchling", "hatch-regex-commit"] +build-backend = "hatchling.build" + +[project] +name = "fastapi-users-db-sqlalchemy" +authors = [ + { name = "François Voron", email = "fvoron@gmail.com" }, +] +description = "FastAPI Users database adapter for SQLAlchemy" +readme = "README.md" +dynamic = ["version"] classifiers = [ "License :: OSI Approved :: MIT License", "Development Status :: 5 - Production/Stable", + "Framework :: FastAPI", "Framework :: AsyncIO", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3 :: Only", "Topic :: Internet :: WWW/HTTP :: Session", ] -description-file = "README.md" -requires-python = ">=3.7" -requires = [ - "fastapi-users >= 9.1.0", - "sqlalchemy[asyncio] >=1.4", +requires-python = ">=3.9" +dependencies = [ + "fastapi-users >= 10.0.0", + "sqlalchemy[asyncio] >=2.0.0,<2.1.0", ] -[tool.flit.metadata.urls] +[project.urls] Documentation = "https://fastapi-users.github.io/fastapi-users" +Source = "https://github.com/fastapi-users/fastapi-users-db-sqlalchemy" diff --git a/requirements.dev.txt b/requirements.dev.txt deleted file mode 100644 index 1eca1c2..0000000 --- a/requirements.dev.txt +++ /dev/null @@ -1,20 +0,0 @@ --r requirements.txt - -aiosqlite -flake8 -pytest -requests -isort -pytest-asyncio -flake8-docstrings -greenlet -black -mypy -codecov -pytest-cov -pytest-mock -asynctest -flit -bumpversion -httpx -asgi_lifespan diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 626054b..0000000 --- a/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -fastapi-users >= 9.1.0 -sqlalchemy[asyncio,mypy] >=1.4 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 1bc99a2..0000000 --- a/setup.cfg +++ /dev/null @@ -1,14 +0,0 @@ -[bumpversion] -current_version = 3.0.1 -commit = True -tag = True - -[bumpversion:file:fastapi_users_db_sqlalchemy/__init__.py] -search = __version__ = "{current_version}" -replace = __version__ = "{new_version}" - -[flake8] -exclude = docs -max-line-length = 88 -docstring-convention = numpy -ignore = D1 diff --git a/tests/conftest.py b/tests/conftest.py index b9bbfcd..b7661c6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,59 +1,47 @@ -import asyncio -from typing import Optional +import os +from typing import Any, Optional import pytest -from fastapi_users import models +from fastapi_users import schemas - -class User(models.BaseUser): - first_name: Optional[str] +DATABASE_URL = os.getenv( + "DATABASE_URL", "sqlite+aiosqlite:///./test-sqlalchemy-user.db" +) -class UserCreate(models.BaseUserCreate): +class User(schemas.BaseUser): first_name: Optional[str] -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserDB(User, models.BaseUserDB): - pass +class UserCreate(schemas.BaseUserCreate): + first_name: Optional[str] -class UserOAuth(User, models.BaseOAuthAccountMixin): +class UserUpdate(schemas.BaseUserUpdate): pass -class UserDBOAuth(UserOAuth, UserDB): +class UserOAuth(User, schemas.BaseOAuthAccountMixin): pass -@pytest.fixture(scope="session") -def event_loop(): - """Force the pytest-asyncio loop to be the main one.""" - loop = asyncio.new_event_loop() - yield loop - loop.close() - - @pytest.fixture -def oauth_account1() -> models.BaseOAuthAccount: - return models.BaseOAuthAccount( - oauth_name="service1", - access_token="TOKEN", - expires_at=1579000751, - account_id="user_oauth1", - account_email="king.arthur@camelot.bt", - ) +def oauth_account1() -> dict[str, Any]: + return { + "oauth_name": "service1", + "access_token": "TOKEN", + "expires_at": 1579000751, + "account_id": "user_oauth1", + "account_email": "king.arthur@camelot.bt", + } @pytest.fixture -def oauth_account2() -> models.BaseOAuthAccount: - return models.BaseOAuthAccount( - oauth_name="service2", - access_token="TOKEN", - expires_at=1579000751, - account_id="user_oauth2", - account_email="king.arthur@camelot.bt", - ) +def oauth_account2() -> dict[str, Any]: + return { + "oauth_name": "service2", + "access_token": "TOKEN", + "expires_at": 1579000751, + "account_id": "user_oauth2", + "account_email": "king.arthur@camelot.bt", + } diff --git a/tests/test_access_token.py b/tests/test_access_token.py index 34f356d..f0e5fb9 100644 --- a/tests/test_access_token.py +++ b/tests/test_access_token.py @@ -1,28 +1,41 @@ import uuid +from collections.abc import AsyncGenerator from datetime import datetime, timedelta, timezone -from typing import AsyncGenerator import pytest -from fastapi_users.authentication.strategy.db.models import BaseAccessToken +import pytest_asyncio from pydantic import UUID4 from sqlalchemy import exc -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine -from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import DeclarativeBase -from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTable +from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID from fastapi_users_db_sqlalchemy.access_token import ( SQLAlchemyAccessTokenDatabase, - SQLAlchemyBaseAccessTokenTable, + SQLAlchemyBaseAccessTokenTableUUID, ) +from tests.conftest import DATABASE_URL + + +class Base(DeclarativeBase): + pass -class AccessToken(BaseAccessToken): +class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): + pass + + +class User(SQLAlchemyBaseUserTableUUID, Base): pass def create_async_session_maker(engine: AsyncEngine): - return sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) @pytest.fixture @@ -30,71 +43,65 @@ def user_id() -> UUID4: return uuid.uuid4() -@pytest.fixture +@pytest_asyncio.fixture async def sqlalchemy_access_token_db( user_id: UUID4, -) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase, None]: - Base: DeclarativeMeta = declarative_base() - - class AccessTokenTable(SQLAlchemyBaseAccessTokenTable, Base): - pass - - class UserTable(SQLAlchemyBaseUserTable, Base): - pass - - DATABASE_URL = "sqlite+aiosqlite:///./test-sqlalchemy-access-token.db" - engine = create_async_engine( - DATABASE_URL, connect_args={"check_same_thread": False} - ) +) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase[AccessToken], None]: + engine = create_async_engine(DATABASE_URL) sessionmaker = create_async_session_maker(engine) async with engine.begin() as connection: await connection.run_sync(Base.metadata.create_all) - async with sessionmaker() as session: - user = UserTable( - id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere" - ) - session.add(user) - await session.commit() + async with sessionmaker() as session: + user = User( + id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere" + ) + session.add(user) + await session.commit() - yield SQLAlchemyAccessTokenDatabase(AccessToken, session, AccessTokenTable) + yield SQLAlchemyAccessTokenDatabase(session, AccessToken) + async with engine.begin() as connection: await connection.run_sync(Base.metadata.drop_all) @pytest.mark.asyncio -@pytest.mark.db async def test_queries( sqlalchemy_access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken], user_id: UUID4, ): - access_token = AccessToken(token="TOKEN", user_id=user_id) + access_token_create = {"token": "TOKEN", "user_id": user_id} # Create - access_token_db = await sqlalchemy_access_token_db.create(access_token) - assert access_token_db.token == "TOKEN" - assert access_token_db.user_id == user_id + access_token = await sqlalchemy_access_token_db.create(access_token_create) + assert access_token.token == "TOKEN" + assert access_token.user_id == user_id # Update - access_token_db.created_at = datetime.now(timezone.utc) - await sqlalchemy_access_token_db.update(access_token_db) + update_dict = {"created_at": datetime.now(timezone.utc)} + updated_access_token = await sqlalchemy_access_token_db.update( + access_token, update_dict + ) + assert updated_access_token.created_at.replace(microsecond=0) == update_dict[ + "created_at" + ].replace(microsecond=0) # Get by token access_token_by_token = await sqlalchemy_access_token_db.get_by_token( - access_token_db.token + access_token.token ) assert access_token_by_token is not None # Get by token expired access_token_by_token = await sqlalchemy_access_token_db.get_by_token( - access_token_db.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1) + access_token.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1) ) assert access_token_by_token is None # Get by token not expired access_token_by_token = await sqlalchemy_access_token_db.get_by_token( - access_token_db.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1) + access_token.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1) ) assert access_token_by_token is not None @@ -105,23 +112,20 @@ async def test_queries( assert access_token_by_token is None # Delete token - await sqlalchemy_access_token_db.delete(access_token_db) + await sqlalchemy_access_token_db.delete(access_token) deleted_access_token = await sqlalchemy_access_token_db.get_by_token( - access_token_db.token + access_token.token ) assert deleted_access_token is None @pytest.mark.asyncio -@pytest.mark.db async def test_insert_existing_token( sqlalchemy_access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken], user_id: UUID4, ): - access_token = AccessToken(token="TOKEN", user_id=user_id) - await sqlalchemy_access_token_db.create(access_token) + access_token_create = {"token": "TOKEN", "user_id": user_id} + await sqlalchemy_access_token_db.create(access_token_create) with pytest.raises(exc.IntegrityError): - await sqlalchemy_access_token_db.create( - AccessToken(token="TOKEN", user_id=user_id) - ) + await sqlalchemy_access_token_db.create(access_token_create) diff --git a/tests/test_users.py b/tests/test_users.py index fe49e3c..4a83e39 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -1,107 +1,120 @@ -from typing import AsyncGenerator +from collections.abc import AsyncGenerator +from typing import Any import pytest -from sqlalchemy import Column, String, exc -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine -from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base -from sqlalchemy.orm import relationship, sessionmaker +import pytest_asyncio +from sqlalchemy import String, exc +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + mapped_column, + relationship, +) from fastapi_users_db_sqlalchemy import ( - SQLAlchemyBaseOAuthAccountTable, - SQLAlchemyBaseUserTable, + UUID_ID, + SQLAlchemyBaseOAuthAccountTableUUID, + SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase, ) -from tests.conftest import UserDB, UserDBOAuth +from tests.conftest import DATABASE_URL def create_async_session_maker(engine: AsyncEngine): - return sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + return async_sessionmaker(engine, expire_on_commit=False) -@pytest.fixture -async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: - Base: DeclarativeMeta = declarative_base() +class Base(DeclarativeBase): + pass + + +class User(SQLAlchemyBaseUserTableUUID, Base): + first_name: Mapped[str] = mapped_column(String(255), nullable=True) + + +class OAuthBase(DeclarativeBase): + pass + - class User(SQLAlchemyBaseUserTable, Base): - first_name = Column(String, nullable=True) +class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, OAuthBase): + pass - DATABASE_URL = "sqlite+aiosqlite:///./test-sqlalchemy-user.db" - engine = create_async_engine( - DATABASE_URL, connect_args={"check_same_thread": False} +class UserOAuth(SQLAlchemyBaseUserTableUUID, OAuthBase): + first_name: Mapped[str] = mapped_column(String(255), nullable=True) + oauth_accounts: Mapped[list[OAuthAccount]] = relationship( + "OAuthAccount", lazy="joined" ) + + +@pytest_asyncio.fixture +async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: + engine = create_async_engine(DATABASE_URL) sessionmaker = create_async_session_maker(engine) async with engine.begin() as connection: await connection.run_sync(Base.metadata.create_all) - async with sessionmaker() as session: - yield SQLAlchemyUserDatabase(UserDB, session, User) + async with sessionmaker() as session: + yield SQLAlchemyUserDatabase(session, User) + async with engine.begin() as connection: await connection.run_sync(Base.metadata.drop_all) -@pytest.fixture +@pytest_asyncio.fixture async def sqlalchemy_user_db_oauth() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: - Base: DeclarativeMeta = declarative_base() - - class User(SQLAlchemyBaseUserTable, Base): - first_name = Column(String, nullable=True) - oauth_accounts = relationship("OAuthAccount") - - class OAuthAccount(SQLAlchemyBaseOAuthAccountTable, Base): - pass - - DATABASE_URL = "sqlite+aiosqlite:///./test-sqlalchemy-user-oauth.db" - - engine = create_async_engine( - DATABASE_URL, connect_args={"check_same_thread": False} - ) + engine = create_async_engine(DATABASE_URL) sessionmaker = create_async_session_maker(engine) async with engine.begin() as connection: - await connection.run_sync(Base.metadata.create_all) + await connection.run_sync(OAuthBase.metadata.create_all) - async with sessionmaker() as session: - yield SQLAlchemyUserDatabase(UserDBOAuth, session, User, OAuthAccount) + async with sessionmaker() as session: + yield SQLAlchemyUserDatabase(session, UserOAuth, OAuthAccount) - await connection.run_sync(Base.metadata.drop_all) + async with engine.begin() as connection: + await connection.run_sync(OAuthBase.metadata.drop_all) @pytest.mark.asyncio -@pytest.mark.db -async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]): - user = UserDB( - email="lancelot@camelot.bt", - hashed_password="guinevere", - ) +async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[User, UUID_ID]): + user_create = { + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + } # Create - user_db = await sqlalchemy_user_db.create(user) - assert user_db.id is not None - assert user_db.is_active is True - assert user_db.is_superuser is False - assert user_db.email == user.email + user = await sqlalchemy_user_db.create(user_create) + assert user.id is not None + assert user.is_active is True + assert user.is_superuser is False + assert user.email == user_create["email"] # Update - user_db.is_superuser = True - await sqlalchemy_user_db.update(user_db) + updated_user = await sqlalchemy_user_db.update(user, {"is_superuser": True}) + assert updated_user.is_superuser is True # Get by id id_user = await sqlalchemy_user_db.get(user.id) assert id_user is not None - assert id_user.id == user_db.id + assert id_user.id == user.id assert id_user.is_superuser is True # Get by email - email_user = await sqlalchemy_user_db.get_by_email(str(user.email)) + email_user = await sqlalchemy_user_db.get_by_email(str(user_create["email"])) assert email_user is not None - assert email_user.id == user_db.id + assert email_user.id == user.id # Get by uppercased email email_user = await sqlalchemy_user_db.get_by_email("Lancelot@camelot.bt") assert email_user is not None - assert email_user.id == user_db.id + assert email_user.id == user.id # Unknown user unknown_user = await sqlalchemy_user_db.get_by_email("galahad@camelot.bt") @@ -112,36 +125,41 @@ async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]): deleted_user = await sqlalchemy_user_db.get(user.id) assert deleted_user is None + # OAuth without defined table + with pytest.raises(NotImplementedError): + await sqlalchemy_user_db.get_by_oauth_account("foo", "bar") + with pytest.raises(NotImplementedError): + await sqlalchemy_user_db.add_oauth_account(user, {}) + with pytest.raises(NotImplementedError): + oauth_account = OAuthAccount() + await sqlalchemy_user_db.update_oauth_account(user, oauth_account, {}) + @pytest.mark.asyncio -@pytest.mark.db async def test_insert_existing_email( - sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB], + sqlalchemy_user_db: SQLAlchemyUserDatabase[User, UUID_ID], ): - user = UserDB( - email="lancelot@camelot.bt", - hashed_password="guinevere", - ) - await sqlalchemy_user_db.create(user) + user_create = { + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + } + await sqlalchemy_user_db.create(user_create) with pytest.raises(exc.IntegrityError): - await sqlalchemy_user_db.create( - UserDB(email=user.email, hashed_password="guinevere") - ) + await sqlalchemy_user_db.create(user_create) @pytest.mark.asyncio -@pytest.mark.db async def test_queries_custom_fields( - sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB], + sqlalchemy_user_db: SQLAlchemyUserDatabase[User, UUID_ID], ): """It should output custom fields in query result.""" - user = UserDB( - email="lancelot@camelot.bt", - hashed_password="guinevere", - first_name="Lancelot", - ) - await sqlalchemy_user_db.create(user) + user_create = { + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + "first_name": "Lancelot", + } + user = await sqlalchemy_user_db.create(user_create) id_user = await sqlalchemy_user_db.get(user.id) assert id_user is not None @@ -150,43 +168,48 @@ async def test_queries_custom_fields( @pytest.mark.asyncio -@pytest.mark.db async def test_queries_oauth( - sqlalchemy_user_db_oauth: SQLAlchemyUserDatabase[UserDBOAuth], - oauth_account1, - oauth_account2, + sqlalchemy_user_db_oauth: SQLAlchemyUserDatabase[UserOAuth, UUID_ID], + oauth_account1: dict[str, Any], + oauth_account2: dict[str, Any], ): - user = UserDBOAuth( - email="lancelot@camelot.bt", - hashed_password="guinevere", - oauth_accounts=[oauth_account1, oauth_account2], - ) + user_create = { + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + } # Create - user_db = await sqlalchemy_user_db_oauth.create(user) - assert user_db.id is not None - assert hasattr(user_db, "oauth_accounts") - assert len(user_db.oauth_accounts) == 2 + user = await sqlalchemy_user_db_oauth.create(user_create) + assert user.id is not None + + # Add OAuth account + user = await sqlalchemy_user_db_oauth.add_oauth_account(user, oauth_account1) + user = await sqlalchemy_user_db_oauth.add_oauth_account(user, oauth_account2) + assert len(user.oauth_accounts) == 2 + assert user.oauth_accounts[1].account_id == oauth_account2["account_id"] + assert user.oauth_accounts[0].account_id == oauth_account1["account_id"] # Update - user_db.oauth_accounts[0].access_token = "NEW_TOKEN" - await sqlalchemy_user_db_oauth.update(user_db) + user = await sqlalchemy_user_db_oauth.update_oauth_account( + user, user.oauth_accounts[0], {"access_token": "NEW_TOKEN"} + ) + assert user.oauth_accounts[0].access_token == "NEW_TOKEN" # Get by id id_user = await sqlalchemy_user_db_oauth.get(user.id) assert id_user is not None - assert id_user.id == user_db.id + assert id_user.id == user.id assert id_user.oauth_accounts[0].access_token == "NEW_TOKEN" # Get by email - email_user = await sqlalchemy_user_db_oauth.get_by_email(str(user.email)) + email_user = await sqlalchemy_user_db_oauth.get_by_email(user_create["email"]) assert email_user is not None - assert email_user.id == user_db.id + assert email_user.id == user.id assert len(email_user.oauth_accounts) == 2 # Get by OAuth account oauth_user = await sqlalchemy_user_db_oauth.get_by_oauth_account( - oauth_account1.oauth_name, oauth_account1.account_id + oauth_account1["oauth_name"], oauth_account1["account_id"] ) assert oauth_user is not None assert oauth_user.id == user.id