Do you write Python and have always wanted to jump on the GraphQL hype train? Can't choose between Graphene and Ariadne? We suggest trying Strawberry.
Strawberry is a code-first GraphQL library with batteries included. It has 2.6k stars on GitHub. You can use dataclasses and pydantic models for your types, and Strawberry supports both sync and async by design.
In this guide, we'll create an app for creating and retrieving users and their books:
- Add a user.
- Add a book linked to a user.
- Retrieve books with nested users.
We store our data in a database, so our app will be IO-bound. Therefore, we'll write an async app.
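Here is why async helps an IO-bound app. The minimal sketch below uses asyncio.sleep as a stand-in for database round-trips (an assumption; no real database is involved). Three simulated queries awaited concurrently take about as long as the slowest one, not the sum of all three.

```python
import asyncio
import time


async def fake_query(name: str, delay: float) -> str:
    # asyncio.sleep stands in for waiting on a database response;
    # while this coroutine waits, the event loop runs the others
    await asyncio.sleep(delay)
    return name


async def run_concurrently() -> tuple[list[str], float]:
    start = time.perf_counter()
    results = await asyncio.gather(
        fake_query('users', 0.1),
        fake_query('books', 0.1),
        fake_query('authors', 0.1),
    )
    return list(results), time.perf_counter() - start


results, elapsed = asyncio.run(run_concurrently())
# gather keeps the input order; elapsed is expected to be close to 0.1s, not 0.3s
print(results, f'{elapsed:.2f}s')
```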
To create the app, we will:
- use Strawberry with FastAPI on Python 3.10;
- use PostgreSQL 14 as the database, with the encode/databases library and the asyncpg driver;
- use pydantic for models;
- write queries with Relay-style cursor pagination and dataloaders;
- write mutations;
- write tests with pytest.
We'll assume you are already familiar with GraphQL concepts. If not, you can learn them on the official GraphQL website.
Prepare your database
We need a running PostgreSQL instance to run and test our app. Start one in Docker and create a database:
```shell
docker run -d --restart=always -p 5432:5432 -e POSTGRES_PASSWORD=postgres -e POSTGRES_USER=postgres --name postgres postgres:14.3-alpine
psql "postgresql://postgres:postgres@localhost:5432" -c 'create database strawberry;'
```
Dependencies
First things first: let's install our dependencies. We'll use Poetry as the dependency manager.
```shell
pip install poetry
```
Initialize our app:
```shell
poetry init
```
We'll use the default values for everything except the dependencies: answer no when Poetry asks whether to define them interactively.
Install the dependencies:
- Strawberry: our GraphQL library, with the FastAPI extra;
- Databases: a library for working with databases, with the asyncpg driver;
- Uvicorn: an ASGI server to run our app;
- Yoyo-migrations: a database migration tool;
- Psycopg2-binary: the PostgreSQL driver that yoyo-migrations uses to apply our migrations;
- PyPika: a simple SQL query builder.
Dev dependencies:
- pytest for tests;
- pytest-asyncio for async tests;
- pytest-mock to simplify mocking;
- mypy for type checking;
- httpx to test the code with real HTTP requests.
```shell
poetry add 'strawberry-graphql[fastapi]@0.128.0' 'databases[asyncpg]' 'uvicorn[standard]' yoyo-migrations psycopg2-binary pypika
poetry add -D pytest pytest-asyncio pytest-mock mypy httpx
```
App structure
Let's create our app's structure:
```shell
mkdir migrations
touch migrations/000001.init.sql
touch migrations/000001.init.rollback.sql
touch run.py
touch settings.py
mkdir -p src/users
mkdir -p src/books
touch src/__init__.py
touch src/context.py
touch src/users/__init__.py
touch src/users/gql.py
touch src/users/models.py
touch src/books/__init__.py
touch src/books/gql.py
touch src/books/models.py
mkdir -p tests/fixtures
touch tests/__init__.py
touch tests/conftest.py
touch tests/test_users.py
touch tests/test_books.py
touch tests/fixtures/__init__.py
touch tests/fixtures/clients.py
touch tests/fixtures/graphql_client.py
touch .env
touch mypy.ini
```
Minimum working application
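Add the app's port and the connection arguments to .env. The port matches the one published by the Docker container, and the database name matches the one we created:
```text
PORT=8002
POSTGRES_USER=postgres
POSTGRES_PASSWORD=postgres
POSTGRES_HOST=localhost
POSTGRES_PORT=5432
POSTGRES_DB_NAME=strawberry
```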
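settings.py reads these values from the process environment. os.environ does not load .env on its own, so export the variables before starting the app or the tests (in bash, for example: set -a; source .env; set +a):
```python
import os

PORT = int(os.environ.get('PORT', 8001))
DB_USER = os.environ['POSTGRES_USER']
DB_PASSWORD = os.environ['POSTGRES_PASSWORD']
DB_SERVER = os.environ['POSTGRES_HOST']
DB_PORT = int(os.environ['POSTGRES_PORT'])
DB_NAME = os.environ['POSTGRES_DB_NAME']

# asyncpg URL for the app (encode/databases)
CONN_TEMPLATE = (
    'postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}'
)
# plain psycopg2 URL for yoyo-migrations
MIGRATIONS_CONN_TEMPLATE = (
    'postgresql://{user}:{password}@{host}:{port}/{name}'
)

DEFAULT_LIMIT = 100
```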
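If you would rather load the file from Python than from the shell, a small helper can do it. The load_env_file function below is hypothetical and not part of the project. It is a minimal sketch that handles only plain KEY=VALUE lines, with no quoting or variable expansion. Libraries such as python-dotenv cover the general case.

```python
import os
from pathlib import Path


def load_env_file(path: str = '.env') -> dict[str, str]:
    """Hypothetical helper: read KEY=VALUE lines into os.environ.

    Variables already set in the environment win, so values exported
    in the shell can still override the file. Blank lines and '#'
    comments are skipped.
    """
    loaded: dict[str, str] = {}
    for raw_line in Path(path).read_text().splitlines():
        line = raw_line.strip()
        if not line or line.startswith('#') or '=' not in line:
            continue
        key, _, value = line.partition('=')
        key, value = key.strip(), value.strip()
        os.environ.setdefault(key, value)
        loaded[key] = os.environ[key]
    return loaded
```

Calling load_env_file() at the top of settings.py, before the os.environ lookups, would make the values from .env available.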
A minimal working application with a database connection, in run.py:
```python
import functools as fn
import typing as tp

import databases
import settings
import strawberry
import uvicorn
from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter
from strawberry.schema.config import StrawberryConfig


@strawberry.type
class Query:
    """Query."""

    @strawberry.field
    def hello(self) -> str:
        return 'world'


schema = strawberry.Schema(
    query=Query,
    config=StrawberryConfig(auto_camel_case=True),
)


async def startup_db(db: databases.Database):
    await db.connect()


async def shutdown_db(db: databases.Database):
    await db.disconnect()


HOOK_TYPE = tp.Optional[
    tp.Sequence[tp.Callable[[], tp.Any]]
]


def get_app(
    db: databases.Database,
    on_startup: HOOK_TYPE = None,
    on_shutdown: HOOK_TYPE = None,
) -> FastAPI:
    # By default the database connects and disconnects with the app;
    # tests pass empty hook lists and manage the connection themselves
    app = FastAPI(
        on_startup=[fn.partial(startup_db, db)]
        if on_startup is None else on_startup,
        on_shutdown=[fn.partial(shutdown_db, db)]
        if on_shutdown is None else on_shutdown,
    )
    graphql_app = GraphQLRouter(
        schema,
    )
    app.include_router(graphql_app, prefix='/graphql')
    return app


def main() -> None:
    database = databases.Database(
        settings.CONN_TEMPLATE.format(
            user=settings.DB_USER,
            password=settings.DB_PASSWORD,
            port=settings.DB_PORT,
            host=settings.DB_SERVER,
            name=settings.DB_NAME,
        ),
    )
    app = get_app(database)
    uvicorn.run(
        app,
        host='0.0.0.0',
        port=settings.PORT,
    )


if __name__ == '__main__':
    main()
```
Let's launch it:
```shell
poetry run python3 run.py
```
Then open http://localhost:8002/graphql in a browser.
Strawberry's built-in GraphiQL IDE should open. In the query field, let's call our only field:
```graphql
{
  hello
}
```
In response, we'll get:
```json
{
  "data": {
    "hello": "world"
  }
}
```
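GraphiQL is just one client. Any HTTP client can send the same query as a JSON POST body with a query key and an optional variables key. The sketch below uses only the standard library. It assumes the app is running on port 8002 as configured above, and build_graphql_request is a hypothetical helper name.

```python
import json
import urllib.request
from typing import Any, Optional


def build_graphql_request(
    url: str,
    query: str,
    variables: Optional[dict[str, Any]] = None,
) -> urllib.request.Request:
    # GraphQL over HTTP: a JSON object with "query" and optional "variables"
    body: dict[str, Any] = {'query': query}
    if variables:
        body['variables'] = variables
    return urllib.request.Request(
        url,
        data=json.dumps(body).encode(),
        headers={'Content-Type': 'application/json'},
        method='POST',
    )


request = build_graphql_request('http://localhost:8002/graphql', '{ hello }')
# With the server running, sending it should return the same JSON as above:
# with urllib.request.urlopen(request) as resp:
#     print(json.load(resp))
```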
Models
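The migration, migrations/000001.init.sql:
```sql
create table users (
    id bigserial primary key,
    name text not null
);

create table books (
    id bigserial primary key,
    user_id bigint not null references users(id) on delete cascade,
    title text not null
);
```
The rollback migration, migrations/000001.init.rollback.sql:
```sql
drop table books;
drop table users;
```
The test fixtures below apply these migrations to the test database automatically. For the app's own database, apply them with yoyo, for example: yoyo apply --database postgresql://postgres:postgres@localhost:5432/strawberry ./migrations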
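The on delete cascade clause means that deleting a user also deletes their books. You can see this behaviour without PostgreSQL by using the standard-library sqlite3 module. SQLite is swapped in here for illustration only: it uses integer primary key instead of bigserial, and foreign keys must be switched on explicitly.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
# SQLite enforces foreign keys only when this pragma is enabled
conn.execute('pragma foreign_keys = on')
conn.executescript(
    """
    create table users (
        id integer primary key,
        name text not null
    );
    create table books (
        id integer primary key,
        user_id integer not null references users(id) on delete cascade,
        title text not null
    );
    """
)
conn.execute("insert into users (name) values ('Ayn Rand')")
conn.execute("insert into books (user_id, title) values (1, 'Anthem')")
conn.execute('delete from users where id = 1')

# The cascade removed the book together with its user
remaining = conn.execute('select count(*) from books').fetchone()[0]
print(remaining)  # 0
```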
Now let's describe the models and the functions that create and retrieve them. We'll use pydantic models to define our models and structures.
Users, src/users/models.py:
```python
from typing import Optional

from databases import Database
import pydantic
from pypika.dialects import PostgreSQLQuery as Query, Table


class User(pydantic.BaseModel):
    id: int
    name: str


class CreateUserInput(pydantic.BaseModel):
    name: str


async def get_users(
    db: Database,
    ids: Optional[list[int]] = None,
) -> list[User]:
    users_tb = Table('users')
    query = Query.from_(users_tb).select(users_tb.star)
    if ids is not None:
        if not ids:
            return []
        query = query.where(
            users_tb.field('id').isin(ids),
        )
    return [
        User(**el._mapping)
        for el in await db.fetch_all(query=str(query))
    ]


async def create_user(
    db: Database,
    create_input: CreateUserInput,
) -> User:
    users_tb = Table('users')
    values = create_input.dict()
    # Pass all values as one list so they form a single row;
    # calling insert() once per value would add one row per value
    query = Query.into(users_tb).columns(
        tuple(values.keys()),
    ).returning(users_tb.star).insert(
        list(values.values()),
    )
    row = await db.fetch_one(query=str(query))
    if not row:
        raise ValueError('empty row returned')
    return User(**row._mapping)
```
Books, src/books/models.py:
```python
from typing import Optional

from databases import Database
import pydantic
from pypika.dialects import PostgreSQLQuery as Query, Table


class Book(pydantic.BaseModel):
    id: int
    title: str
    user_id: int


class CreateBookInput(pydantic.BaseModel):
    title: str
    user_id: int


async def get_books(
    db: Database,
    ids: Optional[list[int]] = None,
    after: str | None = None,
    first: int | None = None,
) -> list[Book]:
    books_tb = Table('books')
    query = Query.from_(books_tb).select(books_tb.star)
    if ids is not None:
        if not ids:
            return []
        query = query.where(
            books_tb.field('id').isin(ids),
        )
    if after:
        query = query.where(
            books_tb.field('id').gt(int(after)),
        )
    # The id cursor is only meaningful if rows come back ordered by id
    query = query.orderby(books_tb.field('id'))
    if first:
        query = query.limit(first)
    return [
        Book(**el._mapping)
        for el in await db.fetch_all(query=str(query))
    ]


async def create_book(
    db: Database,
    create_input: CreateBookInput,
) -> Book:
    books_tb = Table('books')
    query = Query.into(books_tb).columns(
        tuple(create_input.dict().keys()),
    ).returning(books_tb.star).insert(
        list(create_input.dict().values()),
    )
    row = await db.fetch_one(query=str(query))
    if not row:
        raise ValueError('empty row returned')
    return Book(**row._mapping)
```
Besides the filter by ids, we've already implemented cursor pagination in get_books:
- first is like limit in limit-offset pagination;
- after is the cursor; we use the book id as the cursor.
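Using the id as the cursor makes this keyset pagination: the query has the shape where id > after order by id limit first. Any such query needs the order by, because without it the database may return rows in any order and the cursor loses its meaning. The minimal sketch below runs the same query shape against an in-memory SQLite table, standing in for PostgreSQL with the standard library only. fetch_page is a hypothetical name.

```python
import sqlite3
from typing import Optional

conn = sqlite3.connect(':memory:')
conn.execute('create table books (id integer primary key, title text)')
conn.executemany(
    'insert into books (title) values (?)',
    [(f'Book {n}',) for n in range(1, 9)],  # ids 1..8
)


def fetch_page(after: Optional[str], first: int) -> list[tuple[int, str]]:
    # Keyset pagination: skip past the cursor, then order and limit
    sql = 'select id, title from books'
    params: list = []
    if after:
        sql += ' where id > ?'
        params.append(int(after))
    sql += ' order by id limit ?'
    params.append(first)
    return conn.execute(sql, params).fetchall()


page = fetch_page(after='2', first=3)
print([row[0] for row in page])  # [3, 4, 5]
```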
User types
Let's define our users' type, UserType.
We'll convert our pydantic types into Strawberry types with Strawberry's pydantic integration.
strawberry.auto means that:
- the field is copied from the model passed as the model argument of the type decorator;
- its type is handled by the library;
- its GraphQL name is generated automatically following the camelCase convention.
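The camelCase naming is switched on by StrawberryConfig(auto_camel_case=True) in run.py. The hypothetical to_camel_case function below gives a rough illustration of the rule. Strawberry has its own implementation, and it may handle edge cases such as leading underscores differently.

```python
def to_camel_case(snake: str) -> str:
    # Keep the first word as-is and capitalize the first letter of the rest
    first, *rest = snake.split('_')
    return first + ''.join(word[:1].upper() + word[1:] for word in rest)


print(to_camel_case('user_id'))        # userId
print(to_camel_case('has_next_page'))  # hasNextPage
```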
To query users, let's create a UserQuery type. extend=True tells Strawberry to render our type as extend type Query in the GraphQL schema.
To create a resolver named user, we'll write a user function with the field decorator. The function receives an id of type int as an argument. Info[Context, None] is Strawberry's typed resolver info; we'll talk about info and user_loader later.
Similarly to the user type, we'll write the creation input type. The difference is that we use the input decorator instead of type.
The input type is used by the mutation. The mutation function is almost the same as the user resolver, but it uses the mutation decorator.
Annotated is used here to rename the argument from the automatically generated createInputGql (derived from create_input_gql) to just input.
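Annotated attaches extra metadata to a type hint, and a library can read it back with typing.get_type_hints(..., include_extras=True). The sketch below shows that mechanism using a hypothetical ArgumentName marker. It is not Strawberry's code, only an illustration of how a renamed argument can be discovered.

```python
from dataclasses import dataclass
from typing import Annotated, get_args, get_type_hints


@dataclass(frozen=True)
class ArgumentName:
    """Hypothetical marker, playing the role of strawberry.argument(name=...)."""
    name: str


def create_user_resolver(
    create_input_gql: Annotated[dict, ArgumentName('input')],
) -> None:
    ...


def public_argument_names(func) -> dict[str, str]:
    # Map Python parameter names to the names exposed to clients
    names = {}
    for param, hint in get_type_hints(func, include_extras=True).items():
        if param == 'return':
            continue
        # get_args(Annotated[T, m1, m2]) == (T, m1, m2)
        markers = [m for m in get_args(hint)[1:] if isinstance(m, ArgumentName)]
        names[param] = markers[0].name if markers else param
    return names


print(public_argument_names(create_user_resolver))  # {'create_input_gql': 'input'}
```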
With create_input = create_input_gql.to_pydantic(), we build the pydantic model from the Strawberry input type. The fields stay the same, but the conversion runs pydantic validation and «calms» mypy down. :)
The resulting code in src/users/gql.py:
```python
from typing import Annotated, cast

import strawberry
from src.context import Context
from src.users.models import (
    CreateUserInput,
    User,
    create_user,
)
from strawberry.types import Info


@strawberry.experimental.pydantic.type(model=User, name='User')
class UserType:
    id: strawberry.auto
    name: strawberry.auto


@strawberry.type(name='Query', extend=True)
class UserQuery:
    @strawberry.field(name='user')
    async def user(
        self,
        info: Info[Context, None],
        id: int,
    ) -> UserType | None:
        # Batched and cached per request by the dataloader
        return await info.context.user_loader.load(id)


@strawberry.experimental.pydantic.input(
    model=CreateUserInput, name='CreateUserInput',
)
class CreateUserInputType:
    name: strawberry.auto


@strawberry.type(name='Mutation', extend=True)
class UserMutation:
    @strawberry.mutation()
    async def create_user(
        self,
        info: Info[Context, None],
        create_input_gql: Annotated[
            CreateUserInputType,
            strawberry.argument(name='input'),
        ],
    ) -> UserType:
        create_input = create_input_gql.to_pydantic()
        return cast(
            UserType,
            await create_user(
                db=info.context.db,
                create_input=create_input,
            ),
        )
```
Let's fulfil our promise about Info, Context and user_loader, and fill out src/context.py.
Context is a bag of useful stuff shared by all resolvers of a request, a simple form of dependency injection. We'll put the database connection and the dataloaders there.
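Let's write a users' dataloader using a partial function. partial binds the database connection to the loader's batch function when the context is created at the beginning of each request. So when we call the dataloader's load method, we only need to pass the ID.
A batch function must return exactly one result per key, in the order of the keys. get_users does not guarantee either: it returns rows in whatever order the database produces and skips ids that don't exist. So we wrap it in a small load_users function that reorders the results and fills the gaps with None.
src/context.py now:
```python
from functools import partial
from typing import TYPE_CHECKING, Optional

import databases
from strawberry.dataloader import DataLoader
from strawberry.fastapi import BaseContext

from src.users.models import User, get_users

if TYPE_CHECKING:
    from src.users.gql import UserType


async def load_users(
    db: databases.Database,
    ids: list[int],
) -> list[Optional[User]]:
    # DataLoader expects one result per key, in the same order as the keys
    users = {user.id: user for user in await get_users(db, ids=ids)}
    return [users.get(user_id) for user_id in ids]


class Context(BaseContext):
    """Per-request context: database connection and dataloaders."""

    db: databases.Database

    def __init__(
        self,
        db: databases.Database,
        user_loader: DataLoader[int, Optional['UserType']],
    ):
        super().__init__()
        self.db = db
        self.user_loader = user_loader


def get_context(db: databases.Database) -> Context:
    # Called for every request, so each request gets a fresh loader cache
    return Context(
        db=db,
        user_loader=DataLoader(
            load_fn=partial(load_users, db),
        ),
    )
```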
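A dataloader collects the keys requested during one tick of the event loop and resolves them all with a single batch call. The batch function must return one result per key, in the same order as the keys. The hypothetical TinyLoader below makes these mechanics concrete using plain asyncio. It is a deliberately small sketch, not Strawberry's DataLoader, which also caches results.

```python
import asyncio
from typing import Awaitable, Callable, Generic, Hashable, Optional, TypeVar

K = TypeVar('K', bound=Hashable)
V = TypeVar('V')


class TinyLoader(Generic[K, V]):
    """Hypothetical, minimal dataloader: batching only, no caching."""

    def __init__(self, batch_fn: Callable[[list[K]], Awaitable[list[V]]]):
        self.batch_fn = batch_fn
        self.pending: list[tuple[K, asyncio.Future]] = []
        self.dispatch_task: Optional[asyncio.Task] = None

    def load(self, key: K) -> asyncio.Future:
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        if not self.pending:
            # First key of a batch: the task only starts once the caller
            # yields to the event loop, so later load() calls join this batch
            self.dispatch_task = loop.create_task(self._dispatch())
        self.pending.append((key, future))
        return future

    async def _dispatch(self) -> None:
        batch, self.pending = self.pending, []
        values = await self.batch_fn([key for key, _ in batch])
        if len(values) != len(batch):
            error = ValueError('batch_fn must return one value per key')
            for _, future in batch:
                future.set_exception(error)
            return
        # Results are matched to keys by position
        for (_, future), value in zip(batch, values):
            future.set_result(value)


USERS = {1: 'Ayn Rand', 2: 'Fyodor Dostoevsky'}
batches: list[list[int]] = []


async def load_user_names(ids: list[int]) -> list[Optional[str]]:
    batches.append(ids)
    # Answer in the order of the requested ids, None for unknown ids
    return [USERS.get(user_id) for user_id in ids]


async def main() -> list[Optional[str]]:
    loader: TinyLoader[int, Optional[str]] = TinyLoader(load_user_names)
    return await asyncio.gather(loader.load(2), loader.load(1), loader.load(3))


result = asyncio.run(main())
print(result)   # ['Fyodor Dostoevsky', 'Ayn Rand', None]
print(batches)  # [[2, 1, 3]]
```

Three load calls produce a single batch call. This is why a batch function built on get_users has to reorder its rows by the requested ids.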
We'll modify run.py so our app knows about the context. We'll also add the mutations and the user queries.
Add context_getter to graphql_app:
```python
graphql_app = GraphQLRouter(
    schema,
    context_getter=fn.partial(get_context, db),
)
```
We'll remove hello from Query and inherit from UserQuery instead:
```python
@strawberry.type
class Query(
    UserQuery,
):
    """Query."""
```
Mutations:
```python
@strawberry.type
class Mutation(
    UserMutation,
):
    """Mutations."""
```
We should pass the mutations to the schema object:
```python
schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
    config=StrawberryConfig(auto_camel_case=True),
)
```
The resulting run.py:
```python
import functools as fn
import typing as tp

import databases
import settings
import strawberry
import uvicorn
from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter
from strawberry.schema.config import StrawberryConfig

from src.context import get_context
from src.users.gql import UserMutation, UserQuery


@strawberry.type
class Query(
    UserQuery,
):
    """Query."""


@strawberry.type
class Mutation(
    UserMutation,
):
    """Mutations."""


schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
    config=StrawberryConfig(auto_camel_case=True),
)


async def startup_db(db: databases.Database):
    await db.connect()


async def shutdown_db(db: databases.Database):
    await db.disconnect()


HOOK_TYPE = tp.Optional[
    tp.Sequence[tp.Callable[[], tp.Any]]
]


def get_app(
    db: databases.Database,
    on_startup: HOOK_TYPE = None,
    on_shutdown: HOOK_TYPE = None,
) -> FastAPI:
    app = FastAPI(
        on_startup=[fn.partial(startup_db, db)]
        if on_startup is None else on_startup,
        on_shutdown=[fn.partial(shutdown_db, db)]
        if on_shutdown is None else on_shutdown,
    )
    graphql_app = GraphQLRouter(
        schema,
        # A new context (and new dataloaders) for every request
        context_getter=fn.partial(get_context, db),
    )
    app.include_router(graphql_app, prefix='/graphql')
    return app


def main() -> None:
    database = databases.Database(
        settings.CONN_TEMPLATE.format(
            user=settings.DB_USER,
            password=settings.DB_PASSWORD,
            port=settings.DB_PORT,
            host=settings.DB_SERVER,
            name=settings.DB_NAME,
        ),
    )
    app = get_app(database)
    uvicorn.run(
        app,
        host='0.0.0.0',
        port=settings.PORT,
    )


if __name__ == '__main__':
    main()
```
User tests
We won't test our application manually. We're professionals, right?
Strawberry ships with test clients out of the box, but our resolvers depend on a database, so we need to prepare some tooling.
Let's take Strawberry's base test client and modify it so it can make requests to the real server:
- Copy query from the base class and keep only an async version of it.
- Implement request: it'll use an httpx client to make requests to the server.
- Simplify _decode and _build_body.
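The server answers every request with a JSON object containing data, errors and, optionally, extensions. The query method wraps that object in a Response and, when asserts_errors is set, fails the test if errors is present. Below is a minimal standard-library sketch of that decoding step. decode_response and GraphQLResult are hypothetical names that only mirror the roles of the client's query method and Strawberry's Response.

```python
import json
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class GraphQLResult:
    # Hypothetical stand-in for strawberry.test.Response
    data: Optional[dict[str, Any]]
    errors: Optional[list[dict[str, Any]]]
    extensions: Optional[dict[str, Any]] = None


def decode_response(raw_json: str, asserts_errors: bool = True) -> GraphQLResult:
    raw = json.loads(raw_json)
    result = GraphQLResult(
        data=raw.get('data'),
        errors=raw.get('errors'),
        extensions=raw.get('extensions'),
    )
    if asserts_errors:
        # Fail loudly and show the server's error messages
        assert result.errors is None, result.errors
    return result


ok = decode_response('{"data": {"hello": "world"}}')
print(ok.data)  # {'hello': 'world'}

failed = decode_response(
    '{"data": null, "errors": [{"message": "user not found"}]}',
    asserts_errors=False,
)
print(failed.errors)  # [{'message': 'user not found'}]
```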
tests/fixtures/graphql_client.py:
```python
import json
import typing as tp

import databases
from httpx import AsyncClient
from httpx._types import HeaderTypes, RequestFiles
from strawberry.test import BaseGraphQLTestClient, Response


class GraphQLTestClient(BaseGraphQLTestClient):
    def __init__(self, client: AsyncClient, db: databases.Database):
        self._client = client
        self.db = db

    async def query(
        self,
        query: str,
        variables: tp.Optional[tp.Dict[str, tp.Any]] = None,
        headers: tp.Optional[tp.Dict[str, object]] = None,
        asserts_errors: tp.Optional[bool] = True,
        files: tp.Optional[tp.Dict[str, object]] = None,
    ) -> Response:
        """Override query so that it is async only."""
        body = self._build_body(query, variables, files)
        resp = await self.request(body, headers, files)
        raw_data = self._decode(resp, type='multipart' if files else 'json')
        response = Response(
            errors=raw_data.get('errors'),
            data=raw_data.get('data'),
            extensions=raw_data.get('extensions'),
        )
        if asserts_errors:
            assert response.errors is None, response.errors
        return response

    async def request(
        self,
        body: dict[str, object],
        headers: tp.Optional[dict[str, object]] = None,
        files: tp.Optional[dict[str, object]] = None,
    ):
        """Send a real HTTP request to the app."""
        return await self._client.post(
            '/graphql/',
            # JSON body for plain queries, form data for file uploads
            json=None if files else body,
            data=body if files else None,
            files=tp.cast(tp.Optional[RequestFiles], files),
            headers=tp.cast(tp.Optional[HeaderTypes], headers),
            follow_redirects=True,
        )

    def _build_body(
        self,
        query: str,
        variables: tp.Optional[dict[str, tp.Mapping]] = None,  # type:ignore
        files: tp.Optional[dict[str, object]] = None,
    ) -> dict[str, object]:
        """Build the request body; files use the multipart request spec."""
        body: dict[str, object] = {'query': query}
        if variables:
            body['variables'] = variables
        if files:
            assert variables is not None
            assert files is not None
            file_map = self._build_multipart_file_map(variables, files)
            body = {
                'operations': json.dumps(body),
                'map': json.dumps(file_map),
            }
        return body

    def _decode(
        self,
        response,
        type: tp.Literal['multipart', 'json'],
    ):
        """Always decode to json."""
        return response.json()
```
Fixtures and helpers that use the client we made:
- event_loop: one event loop for the whole test session;
- prepare_db: runs the migrations and creates a template database for a fast dump and restore;
- recreate_db: restores the test database from the template;
- client: a client to make requests to our server.
We use AsyncClient from httpx and inject our app into it. This lets us make real HTTP requests instead of mocking.
force_rollback = True runs each test inside a transaction that is rolled back afterwards. This is faster than restoring from a template, but it may misbehave with nested transactions and with side effects caused by concurrency. If something goes wrong at the database level, try setting force_rollback = False.
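The idea behind force_rollback is that everything a test writes happens inside one transaction, which is rolled back at the end, so nothing persists between tests. Below, the same idea is sketched with the standard-library sqlite3 module instead of databases and PostgreSQL (a swapped-in engine, for illustration only). run_in_transaction and sample_test are hypothetical names.

```python
import sqlite3

# isolation_level=None: transactions are managed explicitly with BEGIN/ROLLBACK
conn = sqlite3.connect(':memory:', isolation_level=None)
conn.execute('create table users (id integer primary key, name text not null)')


def run_in_transaction(test) -> None:
    conn.execute('begin')
    try:
        test(conn)
    finally:
        # Undo everything the test wrote, whether it passed or failed
        conn.execute('rollback')


def sample_test(db: sqlite3.Connection) -> None:
    db.execute("insert into users (name) values ('John Doe')")
    # Inside the transaction the new row is visible
    assert db.execute('select count(*) from users').fetchone()[0] == 1


run_in_transaction(sample_test)
remaining = conn.execute('select count(*) from users').fetchone()[0]
print(remaining)  # 0
```

Because the table is empty again after each run, the same test can run any number of times without recreating the database.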
So in tests/fixtures/clients.py:
```python
import asyncio

import databases
import pytest
import pytest_asyncio
import settings
from httpx import AsyncClient
from run import get_app
from tests.fixtures.graphql_client import GraphQLTestClient
from yoyo import get_backend, read_migrations

TEST_DB_NAME = 'test_{name}'.format(name=settings.DB_NAME)


@pytest.fixture(scope='session')
def event_loop():
    # One loop for the session, shared by session-scoped async fixtures
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


@pytest_asyncio.fixture(autouse=True, scope='session')
async def prepare_db():
    postgres_db = databases.Database(
        settings.CONN_TEMPLATE.format(
            user=settings.DB_USER,
            password=settings.DB_PASSWORD,
            port=settings.DB_PORT,
            host=settings.DB_SERVER,
            name='postgres',
        ),
    )
    async with postgres_db as create_conn:
        await create_conn.execute(
            'drop database if exists {name};'.format(name=TEST_DB_NAME),
        )
        await create_conn.execute(
            'drop database if exists {name}_template;'.format(
                name=TEST_DB_NAME,
            ),
        )
        await create_conn.execute(
            'create database {name};'.format(name=TEST_DB_NAME),
        )
    # yoyo applies the migrations over a synchronous psycopg2 connection
    backend = get_backend(
        settings.MIGRATIONS_CONN_TEMPLATE.format(
            user=settings.DB_USER,
            password=settings.DB_PASSWORD,
            port=settings.DB_PORT,
            host=settings.DB_SERVER,
            name=TEST_DB_NAME,
        ),
    )
    migrations = read_migrations('./migrations')
    with backend.lock():
        backend.apply_migrations(backend.to_apply(migrations))
    del backend  # connection is not released otherwise
    async with postgres_db as template_conn:
        # A database can only serve as a template while nobody is connected
        await template_conn.execute(
            'select pg_terminate_backend(pid) '
            'from pg_stat_activity '
            "where datname = '{name}'".format(name=TEST_DB_NAME),
        )
        await template_conn.execute(
            'create database {name}_template template {name};'.format(
                name=TEST_DB_NAME,
            ),
        )
    try:
        yield
    finally:
        async with postgres_db as drop_conn:
            await drop_conn.execute(
                'drop database {name};'.format(name=TEST_DB_NAME),
            )
            await drop_conn.execute(
                'drop database {name}_template;'.format(name=TEST_DB_NAME),
            )


async def recreate_db() -> None:
    postgres_db = databases.Database(
        settings.CONN_TEMPLATE.format(
            user=settings.DB_USER,
            password=settings.DB_PASSWORD,
            port=settings.DB_PORT,
            host=settings.DB_SERVER,
            name='postgres',
        ),
    )
    async with postgres_db as drop_conn:
        # "with (FORCE)" requires PostgreSQL 13+
        await drop_conn.execute(
            'drop database {name} with (FORCE);'.format(
                name=TEST_DB_NAME,
            ),
        )
        await drop_conn.execute(
            'create database {name} template {name}_template;'.format(
                name=TEST_DB_NAME,
            ),
        )


@pytest_asyncio.fixture()
async def client():
    # run in transaction - makes tests faster
    # but it can cause problems in case of concurrency
    force_rollback = True
    database = databases.Database(
        settings.CONN_TEMPLATE.format(
            user=settings.DB_USER,
            password=settings.DB_PASSWORD,
            port=settings.DB_PORT,
            host=settings.DB_SERVER,
            name=TEST_DB_NAME,
        ),
        force_rollback=force_rollback,
    )
    async with database as db:
        # Empty hooks: the fixture, not the app, owns the connection
        async with AsyncClient(
            app=get_app(db=database, on_startup=[], on_shutdown=[]),
            base_url='http://test',
        ) as test_client:
            yield GraphQLTestClient(test_client, db)
    if not force_rollback:
        await recreate_db()
```
Add our fixtures to tests/conftest.py:
```python
pytest_plugins = (
    'tests.fixtures.clients',
)
```
Let's write a test for user creation in tests/test_users.py. It tests the creation mutation and the user type resolver:
```python
import pytest
from tests.fixtures.clients import GraphQLTestClient

pytestmark = [
    pytest.mark.asyncio,
]

create_user_query = """
    mutation createUser(
        $input: CreateUserInput!
    ) {
        createUser(input: $input) {
            id
        }
    }
"""


async def test_create_user(
    client: GraphQLTestClient,
    mocker,
):
    resp = await client.query(
        query=create_user_query,
        variables={
            'input': {
                'name': 'John Doe',
            },
        },
    )
    assert resp.data is not None
    assert resp.data['createUser'] == {'id': mocker.ANY}
```
Run our tests:
```shell
poetry run pytest
```
Book types
The mutation for creating books is the same as the users' one, but the type and the retrieval are more interesting.
The Book type will have a user resolver to fetch the related user. It uses the users' dataloader. The user id is passed to the dataloader from root, the pydantic model Book behind BookType.
To get books, we'll implement pagination following the Relay connection spec, which is recommended by the official GraphQL website.
The PageInfo type holds pagination info, and the generic types Connection and Edge hold the results and the individual edges respectively. In addition to edges from the Relay spec, we'll add nodes to Connection. GitLab's GraphQL API does the same; it's neat and sometimes convenient.
after and first are the arguments of the books query; we covered them in the models section.
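The resolver asks the database for first + 1 rows. The extra row exists only to tell whether there is a next page, and it is dropped before returning, but only when it was actually fetched. Below is a minimal, framework-free sketch of that logic with plain dataclasses. build_connection and these classes are hypothetical stand-ins for the Strawberry types.

```python
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

T = TypeVar('T')


@dataclass
class PageInfo:
    has_next_page: bool
    has_previous_page: bool
    start_cursor: Optional[str]
    end_cursor: Optional[str]


@dataclass
class Edge(Generic[T]):
    node: T
    cursor: str


@dataclass
class Connection(Generic[T]):
    page_info: PageInfo
    edges: list[Edge[T]]
    nodes: list[T]


def build_connection(
    rows: list[int], first: int, after: Optional[str],
) -> Connection[int]:
    # rows were fetched with LIMIT first + 1; each row is an id used as cursor
    has_next_page = len(rows) > first
    edges = [Edge(node=row, cursor=str(row)) for row in rows[:first]]
    return Connection(
        page_info=PageInfo(
            has_next_page=has_next_page,
            has_previous_page=after is not None,
            start_cursor=edges[0].cursor if edges else None,
            end_cursor=edges[-1].cursor if edges else None,
        ),
        edges=edges,
        nodes=[edge.node for edge in edges],
    )


full = build_connection([3, 4, 5, 6], first=3, after='2')  # first + 1 rows
last = build_connection([7, 8], first=3, after='6')        # short final page
print(full.nodes, full.page_info.has_next_page)  # [3, 4, 5] True
print(last.nodes, last.page_info.has_next_page)  # [7, 8] False
```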
src/books/gql.py:
```python
from typing import Annotated, Generic, TypeVar, cast

import strawberry
from settings import DEFAULT_LIMIT
from src.books.models import (
    Book,
    CreateBookInput,
    create_book,
    get_books,
)
from src.context import Context
from src.users.gql import UserType
from strawberry.types import Info

GenericType = TypeVar('GenericType')


@strawberry.type
class PageInfo:
    has_next_page: bool
    has_previous_page: bool
    start_cursor: str | None
    end_cursor: str | None


@strawberry.type
class Edge(Generic[GenericType]):
    node: GenericType
    cursor: str


@strawberry.type
class Connection(Generic[GenericType]):
    page_info: PageInfo
    edges: list[Edge[GenericType]]
    nodes: list[GenericType]


@strawberry.experimental.pydantic.type(model=Book, name='Book')
class BookType:
    id: strawberry.auto
    title: strawberry.auto

    @strawberry.field
    async def user(
        self,
        root: Book,
        info: Info[Context, None],
    ) -> UserType:
        # root is the pydantic Book; batch-load its author by user_id
        user = await info.context.user_loader.load(
            key=root.user_id,
        )
        if not user:
            raise ValueError('user not found')
        return user


@strawberry.type(name='Query', extend=True)
class BookQuery:
    @strawberry.field(name='books')
    async def books(
        self,
        info: Info[Context, None],
        after: str | None = None,
        first: int | None = DEFAULT_LIMIT,
    ) -> Connection[BookType]:
        page_size = first if first else DEFAULT_LIMIT
        # Fetch one extra row to find out whether there is a next page
        books = cast(
            list[BookType],
            await get_books(
                db=info.context.db,
                after=after,
                first=page_size + 1,
            ),
        )
        has_next_page = len(books) > page_size
        # Drop the extra row only if it was actually returned
        edges = [
            Edge(
                cursor=str(book.id),
                node=book,
            )
            for book in books[:page_size]
        ]
        return Connection(
            page_info=PageInfo(
                has_next_page=has_next_page,
                has_previous_page=after is not None,
                start_cursor=edges[0].cursor if edges else None,
                end_cursor=edges[-1].cursor if edges else None,
            ),
            edges=edges,
            nodes=[edge.node for edge in edges],
        )


@strawberry.experimental.pydantic.input(
    model=CreateBookInput, name='CreateBookInput',
)
class CreateBookInputType:
    title: strawberry.auto
    user_id: strawberry.auto


@strawberry.type(name='Mutation', extend=True)
class BookMutation:
    @strawberry.mutation()
    async def create_book(
        self,
        info: Info[Context, None],
        create_input_gql: Annotated[
            CreateBookInputType,
            strawberry.argument(name='input'),
        ],
    ) -> BookType:
        create_input = create_input_gql.to_pydantic()
        return cast(
            BookType,
            await create_book(
                db=info.context.db,
                create_input=create_input,
            ),
        )
```
Now let's add books to the schema by modifying run.py:
```python
import functools as fn
import typing as tp

import databases
import settings
import strawberry
import uvicorn
from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter
from strawberry.schema.config import StrawberryConfig

from src.books.gql import BookMutation, BookQuery
from src.context import get_context
from src.users.gql import UserMutation, UserQuery


@strawberry.type
class Query(
    UserQuery,
    BookQuery,
):
    """Query."""


@strawberry.type
class Mutation(
    BookMutation,
    UserMutation,
):
    """Mutations."""


schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
    config=StrawberryConfig(auto_camel_case=True),
)


async def startup_db(db: databases.Database):
    await db.connect()


async def shutdown_db(db: databases.Database):
    await db.disconnect()


HOOK_TYPE = tp.Optional[
    tp.Sequence[tp.Callable[[], tp.Any]]
]


def get_app(
    db: databases.Database,
    on_startup: HOOK_TYPE = None,
    on_shutdown: HOOK_TYPE = None,
) -> FastAPI:
    app = FastAPI(
        on_startup=[fn.partial(startup_db, db)]
        if on_startup is None else on_startup,
        on_shutdown=[fn.partial(shutdown_db, db)]
        if on_shutdown is None else on_shutdown,
    )
    graphql_app = GraphQLRouter(
        schema,
        context_getter=fn.partial(get_context, db),
    )
    app.include_router(graphql_app, prefix='/graphql')
    return app


def main() -> None:
    database = databases.Database(
        settings.CONN_TEMPLATE.format(
            user=settings.DB_USER,
            password=settings.DB_PASSWORD,
            port=settings.DB_PORT,
            host=settings.DB_SERVER,
            name=settings.DB_NAME,
        ),
    )
    app = get_app(database)
    uvicorn.run(
        app,
        host='0.0.0.0',
        port=settings.PORT,
    )


if __name__ == '__main__':
    main()
```
Let's test all of this in tests/test_books.py:
```python
import asyncio

import pytest
from src.books.models import CreateBookInput, create_book
from src.users.models import CreateUserInput, create_user
from tests.fixtures.clients import GraphQLTestClient
from tests.test_users import create_user_query

pytestmark = [
    pytest.mark.asyncio,
]

create_book_query = """
    mutation createBook(
        $input: CreateBookInput!
    ) {
        createBook(input: $input) {
            id
            user {
                id
            }
        }
    }
"""

books_query = """
    query books(
        $after: String
        $first: Int
    ) {
        books(
            after: $after
            first: $first
        ) {
            pageInfo {
                hasNextPage
                hasPreviousPage
                startCursor
                endCursor
            }
            edges {
                cursor
                node {
                    id
                }
            }
            nodes {
                id
            }
        }
    }
"""


async def test_create_book(
    client: GraphQLTestClient,
    mocker,
):
    create_user_resp = await client.query(
        query=create_user_query,
        variables={
            'input': {
                'name': 'Ayn Rand',
            },
        },
    )
    user_id = create_user_resp.data['createUser']['id']  # type: ignore
    resp = await client.query(
        query=create_book_query,
        variables={
            'input': {
                'title': 'Atlas shrugged',
                'userId': user_id,
            },
        },
    )
    assert resp.data is not None
    assert resp.data['createBook'] == {
        'id': mocker.ANY,
        'user': {
            'id': user_id,
        },
    }


async def test_get_books(
    client: GraphQLTestClient,
):
    # Create the fixtures directly through the model layer
    user_inputs = [
        CreateUserInput(name='Ayn Rand'),
        CreateUserInput(name='Fyodor Dostoevsky'),
        CreateUserInput(name='J.K. Rowling'),
    ]
    user_1, user_2, user_3 = await asyncio.gather(
        *(create_user(
            db=client.db,
            create_input=user_input,
        ) for user_input in user_inputs)
    )
    book_inputs = [
        CreateBookInput(title='Atlas Shrugged', user_id=user_1.id),
        CreateBookInput(title='Anthem', user_id=user_1.id),
        CreateBookInput(title='Idiot', user_id=user_2.id),
        CreateBookInput(title='Demons', user_id=user_2.id),
        CreateBookInput(title='Crime and Punishment', user_id=user_2.id),
        CreateBookInput(
            title='Harry Potter and the Philosopher Stone',
            user_id=user_3.id,
        ),
        CreateBookInput(
            title='Harry Potter and the Chamber of Secrets',
            user_id=user_3.id,
        ),
        CreateBookInput(
            title='Harry Potter and the Prisoner of Azkaban',
            user_id=user_3.id,
        ),
    ]
    _, book_2, book_3, book_4, book_5, *_ = await asyncio.gather(
        *(create_book(
            db=client.db,
            create_input=book_input,
        ) for book_input in book_inputs)
    )
    # Three books after book_2: book_3..book_5, with more pages left
    books_resp = await client.query(
        query=books_query,
        variables={
            'after': str(book_2.id),
            'first': 3,
        },
    )
    assert books_resp.data
    assert books_resp.data['books'] == {
        'pageInfo': {
            'hasNextPage': True,
            'hasPreviousPage': True,
            'startCursor': str(book_3.id),
            'endCursor': str(book_5.id),
        },
        'edges': [
            {
                'cursor': str(book_3.id),
                'node': {'id': book_3.id},
            },
            {
                'cursor': str(book_4.id),
                'node': {'id': book_4.id},
            },
            {
                'cursor': str(book_5.id),
                'node': {'id': book_5.id},
            },
        ],
        'nodes': [
            {'id': book_3.id},
            {'id': book_4.id},
            {'id': book_5.id},
        ],
    }
```
You can find the whole project in our public repo.
Strawberry is under active development, so the documentation sometimes lags behind. The active Discord community makes up for it: maintainers are online every day and ready to help with any issue.