Skip to content
Snippets Groups Projects
Verified Commit af8fb8d6 authored by Nik | Klampfradler's avatar Nik | Klampfradler
Browse files

IMplement PubSub between regular requests and subscriptions

parent 1546f65b
No related branches found
No related tags found
1 merge request!1090Draft: Resolve "Enable async framework (channels)"
Pipeline #107855 failed
......@@ -20,6 +20,7 @@ from .util.apps import AppConfig
from .util.core_helpers import (
create_default_celery_schedule,
get_or_create_favicon,
get_redis_connection,
get_site_preferences,
has_person,
)
......@@ -79,6 +80,12 @@ class CoreConfig(AppConfig):
plugin_dir.register(MediaBackupAgeHealthCheck)
plugin_dir.register(BackupJobHealthCheck)
# Connect signals for django-impersonate
from impersonate.signals import session_begin, session_end # noqa
session_begin.connect(self.on_impersonate)
session_end.connect(self.on_impersonate_stop)
@classmethod
def _load_data_checks(cls):
"""Get all data checks from all loaded models."""
......@@ -156,6 +163,51 @@ class CoreConfig(AppConfig):
# Save the associated person to pick up defaults
user.person.save()
# Broadcast change of user/session to other processes
get_redis_connection("default").publish(
f"sessions.{request.session.session_key}", "user_logged_in"
)
get_redis_connection("default").publish(f"users.{request.user.pk}", "user_logged_in")
def user_logged_out(
self, sender: type, request: Optional[HttpRequest], user: "User", **kwargs
) -> None:
# Broadcast change of user/session to other processes
get_redis_connection("default").publish(
f"sessions.{request.session.session_key}", "user_logged_out"
)
get_redis_connection("default").publish(f"users.{request.user.pk}", "user_logged_out")
def on_impersonate(
self,
sender: None,
impersonator: "User",
impersonating: "User",
request: HttpRequest,
**kwargs,
) -> None:
# Broadcast change of user/session to other processes
get_redis_connection("default").publish(
f"sessions.{request.session.session_key}", f"impersonate:{impersonating.pk}"
)
get_redis_connection("default").publish(
f"users.{impersonator.pk}", f"impersonate:{impersonating.pk}"
)
def on_impersonate_stop(
self,
sender: None,
impersonator: "User",
impersonating: "User",
request: HttpRequest,
**kwargs,
) -> None:
# Broadcast change of user/session to other processes
get_redis_connection("default").publish(
f"sessions.{request.session.session_key}", "impersonate_stop"
)
get_redis_connection("default").publish(f"users.{impersonator.pk}", "impersonate_stop")
@classmethod
def get_all_scopes(cls) -> dict[str, str]:
scopes = {
......
import asyncio
from django.apps import apps
from django.contrib.auth import get_user_model
from django.contrib.messages import get_messages
from django.contrib.sessions.backends.db import SessionStore
from django.core.exceptions import PermissionDenied
from django.db.models import Q
from django.utils import timezone
......@@ -14,7 +16,13 @@ from haystack.utils.loading import UnifiedIndex
from ..models import CustomMenu, Notification, PDFFile, Person, TaskUserAssignment
from ..util.apps import AppConfig
from ..util.core_helpers import get_allowed_object_ids, get_app_module, get_app_packages, has_person
from ..util.core_helpers import (
get_allowed_object_ids,
get_app_module,
get_app_packages,
get_redis_connection,
has_person,
)
from .celery_progress import CeleryProgressFetchedMutation, CeleryProgressType
from .custom_menu import CustomMenuType
from .group import GroupType # noqa
......@@ -157,11 +165,54 @@ class Mutation(graphene.ObjectType):
class Subscription(graphene.ObjectType):
now = graphene.DateTime()
who_am_i = graphene.Field(UserType)
async def subscribe_now(root, info):
while True:
yield timezone.now()
await asyncio.sleep(1)
async def subscribe_who_am_i(root, info):
# FIXME This needs revisiting once we migrate to nativ OAuth for Vue client
# Subscribe to session change events
redis = get_redis_connection("default")
pubsub = redis.pubsub()
print(dir(info.context))
subs = [f"sessions.{info.context.session.session_key}"]
if info.context.user and info.context.user.is_authenticated:
subs += [f"users.{info.context.user.pk}"]
pubsub.subscribe(*subs)
for msg in pubsub.listen():
if msg["type"] == "subscribe":
yield info.context.user
elif msg["type"] == "unsubscribe":
return
elif msg["type"] == "message" and msg["data"].startswith(b"user_logged_"):
# The user logged in or out. In both cases, Django will issue
# a new session ID. Unfortunately, the signals are sent with
# the new session, so we have to check for the old session
# manually to find out whether our session was torn down.
was_ours = SessionStore().exists(session_key=info.context.session.session_key)
if was_ours:
# Hang up to force the client to use its new session cookie
return
elif (
msg["type"] == "message"
and msg["channel"].startswith(b"sessions.")
and msg["data"].startswith(b"impersonate:")
):
# Get impersonated user
pk = int(msg["data"].split(b":")[1])
u = get_user_model().get(pk=pk)
yield u
elif (
msg["type"] == "message"
and msg["channel"].startswith(b"sessions.")
and msg["data"] == b"impersonate_stop"
):
yield info.context.user
def build_global_schema():
"""Build global GraphQL schema from all apps."""
......@@ -188,7 +239,9 @@ def build_global_schema():
GlobalMutation = type("GlobalMutation", tuple(mutation_bases), {})
GlobalSubscription = type("GlobalSubscription", tuple(subscription_bases), {})
return graphene.Schema(query=GlobalQuery, mutation=GlobalMutation, subscription=GlobalSubscription)
return graphene.Schema(
query=GlobalQuery, mutation=GlobalMutation, subscription=GlobalSubscription
)
schema = build_global_schema()
......@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Optional, Sequence, Union
from warnings import warn
from django.conf import settings
from django.core.cache import caches
from django.core.exceptions import ImproperlyConfigured
from django.core.files import File
from django.db.models import Model, QuerySet
......@@ -21,6 +22,7 @@ from django.utils.module_loading import import_string
from cachalot.api import invalidate
from cachalot.signals import post_invalidation
from cache_memoize import cache_memoize
from redis import Redis
def copyright_years(years: Sequence[int], separator: str = ", ", joiner: str = "") -> str:
......@@ -403,6 +405,11 @@ def create_default_celery_schedule():
)
def get_redis_connection(cache: str = "default") -> Redis:
"""Get a native redis connection from a cache name."""
return caches[cache]._cache.get_client()
class OOTRouter:
"""Database router for operations that should run out of transaction.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment