Midway point in task refactor - changing direction

This commit is contained in:
Andrew Godwin 2022-11-08 23:06:29 -07:00
parent 8a0a755889
commit 61c324508e
24 changed files with 698 additions and 241 deletions

View file

@ -1,21 +0,0 @@
from django.contrib import admin
from miniq.models import Task
@admin.register(Task)
class TaskAdmin(admin.ModelAdmin):
list_display = ["id", "created", "type", "subject", "completed", "failed"]
ordering = ["-created"]
actions = ["reset"]
@admin.action(description="Reset Task")
def reset(self, request, queryset):
queryset.update(
failed=None,
completed=None,
locked=None,
locked_by=None,
error=None,
)

View file

@ -1,48 +0,0 @@
# Generated by Django 4.1.3 on 2022-11-07 04:19
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = []
operations = [
migrations.CreateModel(
name="Task",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
(
"type",
models.CharField(
choices=[
("identity_fetch", "Identity Fetch"),
("inbox_item", "Inbox Item"),
("follow_request", "Follow Request"),
("follow_acknowledge", "Follow Acknowledge"),
],
max_length=500,
),
),
("priority", models.IntegerField(default=0)),
("subject", models.TextField()),
("payload", models.JSONField(blank=True, null=True)),
("error", models.TextField(blank=True, null=True)),
("created", models.DateTimeField(auto_now_add=True)),
("completed", models.DateTimeField(blank=True, null=True)),
("failed", models.DateTimeField(blank=True, null=True)),
("locked", models.DateTimeField(blank=True, null=True)),
("locked_by", models.CharField(blank=True, max_length=500, null=True)),
],
),
]

View file

@ -1,71 +0,0 @@
from typing import Optional
from django.db import models, transaction
from django.utils import timezone
class Task(models.Model):
"""
A task that must be done by a queue processor
"""
class TypeChoices(models.TextChoices):
identity_fetch = "identity_fetch"
inbox_item = "inbox_item"
follow_request = "follow_request"
follow_acknowledge = "follow_acknowledge"
type = models.CharField(max_length=500, choices=TypeChoices.choices)
priority = models.IntegerField(default=0)
subject = models.TextField()
payload = models.JSONField(blank=True, null=True)
error = models.TextField(blank=True, null=True)
created = models.DateTimeField(auto_now_add=True)
completed = models.DateTimeField(blank=True, null=True)
failed = models.DateTimeField(blank=True, null=True)
locked = models.DateTimeField(blank=True, null=True)
locked_by = models.CharField(max_length=500, blank=True, null=True)
def __str__(self):
return f"{self.id}/{self.type}({self.subject})"
@classmethod
def get_one_available(cls, processor_id) -> Optional["Task"]:
"""
Gets one task off the list while reserving it, atomically.
"""
with transaction.atomic():
next_task = cls.objects.filter(locked__isnull=True).first()
if next_task is None:
return None
next_task.locked = timezone.now()
next_task.locked_by = processor_id
next_task.save()
return next_task
@classmethod
def submit(cls, type, subject: str, payload=None, deduplicate=True):
# Deduplication is done against tasks that have not started yet only,
# and only on tasks without payloads
if deduplicate and not payload:
if cls.objects.filter(
type=type,
subject=subject,
completed__isnull=True,
failed__isnull=True,
locked__isnull=True,
).exists():
return
cls.objects.create(type=type, subject=subject, payload=payload)
async def complete(self):
await self.__class__.objects.filter(id=self.id).aupdate(
completed=timezone.now()
)
async def fail(self, error):
await self.__class__.objects.filter(id=self.id).aupdate(
failed=timezone.now(),
error=error,
)

View file

@ -1,34 +0,0 @@
import traceback
from users.tasks.follow import handle_follow_request
from users.tasks.identity import handle_identity_fetch
from users.tasks.inbox import handle_inbox_item
class TaskHandler:
handlers = {
"identity_fetch": handle_identity_fetch,
"inbox_item": handle_inbox_item,
"follow_request": handle_follow_request,
}
def __init__(self, task):
self.task = task
self.subject = self.task.subject
self.payload = self.task.payload
async def handle(self):
try:
print(f"Task {self.task}: Starting")
if self.task.type not in self.handlers:
raise ValueError(f"Cannot handle type {self.task.type}")
await self.handlers[self.task.type](
self,
)
await self.task.complete()
print(f"Task {self.task}: Complete")
except BaseException as e:
print(f"Task {self.task}: Error {e}")
traceback.print_exc()
await self.task.fail(f"{e}\n\n" + traceback.format_exc())

View file

@ -1,51 +0,0 @@
import asyncio
import time
import uuid
from asgiref.sync import sync_to_async
from django.http import HttpResponse
from django.views import View
from miniq.models import Task
from miniq.tasks import TaskHandler
class QueueProcessor(View):
"""
A view that takes some items off the queue and processes them.
Tries to limit its own runtime so it's within HTTP timeout limits.
"""
START_TIMEOUT = 30
TOTAL_TIMEOUT = 60
LOCK_TIMEOUT = 200
MAX_TASKS = 20
async def get(self, request):
start_time = time.monotonic()
processor_id = uuid.uuid4().hex
handled = 0
self.tasks = []
# For the first time period, launch tasks
while (time.monotonic() - start_time) < self.START_TIMEOUT:
# Remove completed tasks
self.tasks = [t for t in self.tasks if not t.done()]
# See if there's a new task
if len(self.tasks) < self.MAX_TASKS:
# Pop a task off the queue and run it
task = await sync_to_async(Task.get_one_available)(processor_id)
if task is not None:
self.tasks.append(asyncio.create_task(TaskHandler(task).handle()))
handled += 1
# Prevent busylooping
await asyncio.sleep(0.01)
# TODO: Clean up old locks here
# Then wait for tasks to finish
while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT:
# Remove completed tasks
self.tasks = [t for t in self.tasks if not t.done()]
if not self.tasks:
break
# Prevent busylooping
await asyncio.sleep(1)
return HttpResponse(f"{handled} tasks handled")

8
stator/admin.py Normal file
View file

@ -0,0 +1,8 @@
from django.contrib import admin
from stator.models import StatorTask
@admin.register(StatorTask)
class DomainAdmin(admin.ModelAdmin):
list_display = ["id", "model_label", "instance_pk", "locked_until"]

View file

@ -1,6 +1,6 @@
from django.apps import AppConfig from django.apps import AppConfig
class MiniqConfig(AppConfig): class StatorConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField" default_auto_field = "django.db.models.BigAutoField"
name = "miniq" name = "stator"

162
stator/graph.py Normal file
View file

@ -0,0 +1,162 @@
import datetime
from functools import wraps
from typing import Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
from django.db import models
from django.utils import timezone
class StateGraph:
"""
Represents a graph of possible states and transitions to attempt on them.
Does not support subclasses of existing graphs yet.
"""
states: ClassVar[Dict[str, "State"]]
choices: ClassVar[List[Tuple[str, str]]]
initial_state: ClassVar["State"]
terminal_states: ClassVar[Set["State"]]
def __init_subclass__(cls) -> None:
# Collect state memebers
cls.states = {}
for name, value in cls.__dict__.items():
if name in ["__module__", "__doc__", "states"]:
pass
elif name in ["initial_state", "terminal_states", "choices"]:
raise ValueError(f"Cannot name a state {name} - this is reserved")
elif isinstance(value, State):
value._add_to_graph(cls, name)
elif callable(value) or isinstance(value, classmethod):
pass
else:
raise ValueError(
f"Graph has item {name} of unallowed type {type(value)}"
)
# Check the graph layout
terminal_states = set()
initial_state = None
for state in cls.states.values():
if state.initial:
if initial_state:
raise ValueError(
f"The graph has more than one initial state: {initial_state} and {state}"
)
initial_state = state
if state.terminal:
terminal_states.add(state)
if initial_state is None:
raise ValueError("The graph has no initial state")
cls.initial_state = initial_state
cls.terminal_states = terminal_states
# Generate choices
cls.choices = [(name, name) for name in cls.states.keys()]
class State:
"""
Represents an individual state
"""
def __init__(self, try_interval: float = 300):
self.try_interval = try_interval
self.parents: Set["State"] = set()
self.children: Dict["State", "Transition"] = {}
def _add_to_graph(self, graph: StateGraph, name: str):
self.graph = graph
self.name = name
self.graph.states[name] = self
def __repr__(self):
return f"<State {self.name}>"
def add_transition(
self,
other: "State",
handler: Optional[Union[str, Callable]] = None,
priority: int = 0,
) -> Callable:
def decorator(handler: Union[str, Callable]):
self.children[other] = Transition(
self,
other,
handler,
priority=priority,
)
other.parents.add(self)
# All handlers should be class methods, so do that automatically.
if callable(handler):
return classmethod(handler)
# If we're not being called as a decorator, invoke it immediately
if handler is not None:
decorator(handler)
return decorator
def add_manual_transition(self, other: "State"):
self.children[other] = ManualTransition(self, other)
other.parents.add(self)
@property
def initial(self):
return not self.parents
@property
def terminal(self):
return not self.children
def transitions(self, automatic_only=False) -> List["Transition"]:
"""
Returns all transitions from this State in priority order
"""
if automatic_only:
transitions = [t for t in self.children.values() if t.automatic]
else:
transitions = self.children.values()
return sorted(transitions, key=lambda t: t.priority, reverse=True)
class Transition:
"""
A possible transition from one state to another
"""
def __init__(
self,
from_state: State,
to_state: State,
handler: Union[str, Callable],
priority: int = 0,
):
self.from_state = from_state
self.to_state = to_state
self.handler = handler
self.priority = priority
self.automatic = True
def get_handler(self) -> Callable:
"""
Returns the handler (it might need resolving from a string)
"""
if isinstance(self.handler, str):
self.handler = getattr(self.from_state.graph, self.handler)
return self.handler
class ManualTransition(Transition):
"""
A possible transition from one state to another that cannot be done by
the stator task runner, and must come from an external source.
"""
def __init__(
self,
from_state: State,
to_state: State,
):
self.from_state = from_state
self.to_state = to_state
self.handler = None
self.priority = 0
self.automatic = False

View file

@ -0,0 +1,31 @@
# Generated by Django 4.1.3 on 2022-11-09 05:46
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = []
operations = [
migrations.CreateModel(
name="StatorTask",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("model_label", models.CharField(max_length=200)),
("instance_pk", models.CharField(max_length=200)),
("locked_until", models.DateTimeField(blank=True, null=True)),
("priority", models.IntegerField(default=0)),
],
),
]

191
stator/models.py Normal file
View file

@ -0,0 +1,191 @@
import datetime
from functools import reduce
from typing import Type, cast
from asgiref.sync import sync_to_async
from django.apps import apps
from django.db import models, transaction
from django.utils import timezone
from django.utils.functional import classproperty
from stator.graph import State, StateGraph
class StateField(models.CharField):
"""
A special field that automatically gets choices from a state graph
"""
def __init__(self, graph: Type[StateGraph], **kwargs):
# Sensible default for state length
kwargs.setdefault("max_length", 100)
# Add choices and initial
self.graph = graph
kwargs["choices"] = self.graph.choices
kwargs["default"] = self.graph.initial_state.name
super().__init__(**kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
kwargs["graph"] = self.graph
return name, path, args, kwargs
def from_db_value(self, value, expression, connection):
if value is None:
return value
return self.graph.states[value]
def to_python(self, value):
if isinstance(value, State) or value is None:
return value
return self.graph.states[value]
def get_prep_value(self, value):
if isinstance(value, State):
return value.name
return value
class StatorModel(models.Model):
"""
A model base class that has a state machine backing it, with tasks to work
out when to move the state to the next one.
You need to provide a "state" field as an instance of StateField on the
concrete model yourself.
"""
# When the state last actually changed, or the date of instance creation
state_changed = models.DateTimeField(auto_now_add=True)
# When the last state change for the current state was attempted
# (and not successful, as this is cleared on transition)
state_attempted = models.DateTimeField(blank=True, null=True)
class Meta:
abstract = True
@classmethod
def schedule_overdue(cls, now=None) -> models.QuerySet:
"""
Finds instances of this model that need to run and schedule them.
"""
q = models.Q()
for transition in cls.state_graph.transitions(automatic_only=True):
q = q | transition.get_query(now=now)
return cls.objects.filter(q)
@classproperty
def state_graph(cls) -> Type[StateGraph]:
return cls._meta.get_field("state").graph
def schedule_transition(self, priority: int = 0):
"""
Adds this instance to the queue to get its state transition attempted.
The scheduler will call this, but you can also call it directly if you
know it'll be ready and want to lower latency.
"""
StatorTask.schedule_for_execution(self, priority=priority)
async def attempt_transition(self):
"""
Attempts to transition the current state by running its handler(s).
"""
# Try each transition in priority order
for transition in self.state_graph.states[self.state].transitions(
automatic_only=True
):
success = await transition.get_handler()(self)
if success:
await self.perform_transition(transition.to_state.name)
return
await self.__class__.objects.filter(pk=self.pk).aupdate(
state_attempted=timezone.now()
)
async def perform_transition(self, state_name):
"""
Transitions the instance to the given state name
"""
if state_name not in self.state_graph.states:
raise ValueError(f"Invalid state {state_name}")
await self.__class__.objects.filter(pk=self.pk).aupdate(
state=state_name,
state_changed=timezone.now(),
state_attempted=None,
)
class StatorTask(models.Model):
"""
The model that we use for an internal scheduling queue.
Entries in this queue are up for checking and execution - it also performs
locking to ensure we get closer to exactly-once execution (but we err on
the side of at-least-once)
"""
# appname.modelname (lowercased) label for the model this represents
model_label = models.CharField(max_length=200)
# The primary key of that model (probably int or str)
instance_pk = models.CharField(max_length=200)
# Locking columns (no runner ID, as we have no heartbeats - all runners
# only live for a short amount of time anyway)
locked_until = models.DateTimeField(null=True, blank=True)
# Basic total ordering priority - higher is more important
priority = models.IntegerField(default=0)
def __str__(self):
return f"#{self.pk}: {self.model_label}.{self.instance_pk}"
@classmethod
def schedule_for_execution(cls, model_instance: StatorModel, priority: int = 0):
# We don't do a transaction here as it's fine to occasionally double up
model_label = model_instance._meta.label_lower
pk = model_instance.pk
# TODO: Increase priority of existing if present
if not cls.objects.filter(
model_label=model_label, instance_pk=pk, locked__isnull=True
).exists():
StatorTask.objects.create(
model_label=model_label,
instance_pk=pk,
priority=priority,
)
@classmethod
def get_for_execution(cls, number: int, lock_expiry: datetime.datetime):
"""
Returns up to `number` tasks for execution, having locked them.
"""
with transaction.atomic():
selected = list(
cls.objects.filter(locked_until__isnull=True)[
:number
].select_for_update()
)
cls.objects.filter(pk__in=[i.pk for i in selected]).update(
locked_until=timezone.now()
)
return selected
@classmethod
async def aget_for_execution(cls, number: int, lock_expiry: datetime.datetime):
return await sync_to_async(cls.get_for_execution)(number, lock_expiry)
@classmethod
async def aclean_old_locks(cls):
await cls.objects.filter(locked_until__lte=timezone.now()).aupdate(
locked_until=None
)
async def aget_model_instance(self) -> StatorModel:
model = apps.get_model(self.model_label)
return cast(StatorModel, await model.objects.aget(pk=self.pk))
async def adelete(self):
self.__class__.objects.adelete(pk=self.pk)

69
stator/runner.py Normal file
View file

@ -0,0 +1,69 @@
import asyncio
import datetime
import time
import uuid
from typing import List, Type
from asgiref.sync import sync_to_async
from django.db import transaction
from django.utils import timezone
from stator.models import StatorModel, StatorTask
class StatorRunner:
"""
Runs tasks on models that are looking for state changes.
Designed to run in a one-shot mode, living inside a request.
"""
START_TIMEOUT = 30
TOTAL_TIMEOUT = 60
LOCK_TIMEOUT = 120
MAX_TASKS = 30
def __init__(self, models: List[Type[StatorModel]]):
self.models = models
self.runner_id = uuid.uuid4().hex
async def run(self):
start_time = time.monotonic()
self.handled = 0
self.tasks = []
# Clean up old locks
await StatorTask.aclean_old_locks()
# Examine what needs scheduling
# For the first time period, launch tasks
while (time.monotonic() - start_time) < self.START_TIMEOUT:
self.remove_completed_tasks()
space_remaining = self.MAX_TASKS - len(self.tasks)
# Fetch new tasks
if space_remaining > 0:
for new_task in await StatorTask.aget_for_execution(
space_remaining,
timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
):
self.tasks.append(asyncio.create_task(self.run_task(new_task)))
self.handled += 1
# Prevent busylooping
await asyncio.sleep(0.01)
# Then wait for tasks to finish
while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT:
self.remove_completed_tasks()
if not self.tasks:
break
# Prevent busylooping
await asyncio.sleep(1)
return self.handled
async def run_task(self, task: StatorTask):
# Resolve the model instance
model_instance = await task.aget_model_instance()
await model_instance.attempt_transition()
# Remove ourselves from the database as complete
await task.adelete()
def remove_completed_tasks(self):
self.tasks = [t for t in self.tasks if not t.done()]

View file

@ -0,0 +1,66 @@
import pytest
from stator.graph import State, StateGraph
def test_declare():
"""
Tests a basic graph declaration and various kinds of handler
lookups.
"""
fake_handler = lambda: True
class TestGraph(StateGraph):
initial = State()
second = State()
third = State()
fourth = State()
final = State()
initial.add_transition(second, 60, handler=fake_handler)
second.add_transition(third, 60, handler="check_third")
def check_third(cls):
return True
@third.add_transition(fourth, 60)
def check_fourth(cls):
return True
fourth.add_manual_transition(final)
assert TestGraph.initial_state == TestGraph.initial
assert TestGraph.terminal_states == {TestGraph.final}
assert TestGraph.initial.children[TestGraph.second].get_handler() == fake_handler
assert (
TestGraph.second.children[TestGraph.third].get_handler()
== TestGraph.check_third
)
assert (
TestGraph.third.children[TestGraph.fourth].get_handler().__name__
== "check_fourth"
)
def test_bad_declarations():
"""
Tests that you can't declare an invalid graph.
"""
# More than one initial state
with pytest.raises(ValueError):
class TestGraph(StateGraph):
initial = State()
initial2 = State()
# No initial states
with pytest.raises(ValueError):
class TestGraph(StateGraph):
loop = State()
loop2 = State()
loop.add_transition(loop2, 1, handler="fake")
loop2.add_transition(loop, 1, handler="fake")

17
stator/views.py Normal file
View file

@ -0,0 +1,17 @@
from django.http import HttpResponse
from django.views import View
from stator.runner import StatorRunner
from users.models import Follow
class RequestRunner(View):
"""
Runs a Stator runner within a HTTP request. For when you're on something
serverless.
"""
async def get(self, request):
runner = StatorRunner([Follow])
handled = await runner.run()
return HttpResponse(f"Handled {handled}")

View file

@ -26,7 +26,7 @@ INSTALLED_APPS = [
"core", "core",
"statuses", "statuses",
"users", "users",
"miniq", "stator",
] ]
MIDDLEWARE = [ MIDDLEWARE = [

View file

@ -2,7 +2,7 @@ from django.contrib import admin
from django.urls import path from django.urls import path
from core import views as core from core import views as core
from miniq import views as miniq from stator import views as stator
from users.views import auth, identity from users.views import auth, identity
urlpatterns = [ urlpatterns = [
@ -22,7 +22,7 @@ urlpatterns = [
# Well-known endpoints # Well-known endpoints
path(".well-known/webfinger", identity.Webfinger.as_view()), path(".well-known/webfinger", identity.Webfinger.as_view()),
# Task runner # Task runner
path(".queue/process/", miniq.QueueProcessor.as_view()), path(".stator/runner/", stator.RequestRunner.as_view()),
# Django admin # Django admin
path("djadmin/", admin.site.urls), path("djadmin/", admin.site.urls),
] ]

View file

@ -25,4 +25,4 @@ class IdentityAdmin(admin.ModelAdmin):
@admin.register(Follow) @admin.register(Follow)
class FollowAdmin(admin.ModelAdmin): class FollowAdmin(admin.ModelAdmin):
list_display = ["id", "source", "target", "requested", "accepted"] list_display = ["id", "source", "target", "state"]

View file

@ -0,0 +1,44 @@
# Generated by Django 4.1.3 on 2022-11-07 19:22
import django.utils.timezone
from django.db import migrations, models
import stator.models
import users.models.follow
class Migration(migrations.Migration):
dependencies = [
("users", "0001_initial"),
]
operations = [
migrations.AddField(
model_name="follow",
name="state",
field=stator.models.StateField(
choices=[
("pending", "pending"),
("requested", "requested"),
("accepted", "accepted"),
],
default="pending",
graph=users.models.follow.FollowStates,
max_length=100,
),
),
migrations.AddField(
model_name="follow",
name="state_attempted",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="follow",
name="state_changed",
field=models.DateTimeField(
auto_now_add=True, default=django.utils.timezone.now
),
preserve_default=False,
),
]

View file

@ -0,0 +1,31 @@
# Generated by Django 4.1.3 on 2022-11-08 03:58
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("users", "0002_follow_state_follow_state_attempted_and_more"),
]
operations = [
migrations.RemoveField(
model_name="follow",
name="accepted",
),
migrations.RemoveField(
model_name="follow",
name="requested",
),
migrations.AddField(
model_name="follow",
name="state_locked",
field=models.DateTimeField(blank=True, null=True),
),
migrations.AddField(
model_name="follow",
name="state_runner",
field=models.CharField(blank=True, max_length=100, null=True),
),
]

View file

@ -0,0 +1,21 @@
# Generated by Django 4.1.3 on 2022-11-09 05:15
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("users", "0003_remove_follow_accepted_remove_follow_requested_and_more"),
]
operations = [
migrations.RemoveField(
model_name="follow",
name="state_locked",
),
migrations.RemoveField(
model_name="follow",
name="state_runner",
),
]

View file

@ -2,10 +2,23 @@ from typing import Optional
from django.db import models from django.db import models
from miniq.models import Task from stator.models import State, StateField, StateGraph, StatorModel
class Follow(models.Model): class FollowStates(StateGraph):
pending = State(try_interval=3600)
requested = State()
accepted = State()
@pending.add_transition(requested)
async def try_request(cls, instance):
print("Would have tried to follow")
return False
requested.add_manual_transition(accepted)
class Follow(StatorModel):
""" """
When one user (the source) follows other (the target) When one user (the source) follows other (the target)
""" """
@ -24,8 +37,7 @@ class Follow(models.Model):
uri = models.CharField(blank=True, null=True, max_length=500) uri = models.CharField(blank=True, null=True, max_length=500)
note = models.TextField(blank=True, null=True) note = models.TextField(blank=True, null=True)
requested = models.BooleanField(default=False) state = StateField(FollowStates)
accepted = models.BooleanField(default=False)
created = models.DateTimeField(auto_now_add=True) created = models.DateTimeField(auto_now_add=True)
updated = models.DateTimeField(auto_now=True) updated = models.DateTimeField(auto_now=True)
@ -50,17 +62,15 @@ class Follow(models.Model):
(which can be local or remote). (which can be local or remote).
""" """
if not source.local: if not source.local:
raise ValueError("You cannot initiate follows on a remote Identity") raise ValueError("You cannot initiate follows from a remote Identity")
try: try:
follow = Follow.objects.get(source=source, target=target) follow = Follow.objects.get(source=source, target=target)
except Follow.DoesNotExist: except Follow.DoesNotExist:
follow = Follow.objects.create(source=source, target=target, uri="") follow = Follow.objects.create(source=source, target=target, uri="")
follow.uri = source.actor_uri + f"follow/{follow.pk}/" follow.uri = source.actor_uri + f"follow/{follow.pk}/"
# TODO: Local follow approvals
if target.local: if target.local:
follow.requested = True follow.state = FollowStates.accepted
follow.accepted = True
else:
Task.submit("follow_request", str(follow.pk))
follow.save() follow.save()
return follow return follow

View file

@ -27,3 +27,36 @@ async def handle_follow_request(task_handler):
if response.status_code >= 400: if response.status_code >= 400:
raise ValueError(f"Request error: {response.status_code} {response.content}") raise ValueError(f"Request error: {response.status_code} {response.content}")
await Follow.objects.filter(pk=follow.pk).aupdate(requested=True) await Follow.objects.filter(pk=follow.pk).aupdate(requested=True)
def send_follow_undo(id):
"""
Request a follow from a remote server
"""
follow = Follow.objects.select_related("source", "source__domain", "target").get(
pk=id
)
# Construct the request
request = canonicalise(
{
"@context": "https://www.w3.org/ns/activitystreams",
"id": follow.uri + "#undo",
"type": "Undo",
"actor": follow.source.actor_uri,
"object": {
"id": follow.uri,
"type": "Follow",
"actor": follow.source.actor_uri,
"object": follow.target.actor_uri,
},
}
)
# Sign it and send it
from asgiref.sync import async_to_sync
response = async_to_sync(HttpSignature.signed_request)(
follow.target.inbox_uri, request, follow.source
)
if response.status_code >= 400:
raise ValueError(f"Request error: {response.status_code} {response.content}")
print(response)

View file

@ -16,7 +16,6 @@ from django.views.generic import FormView, TemplateView, View
from core.forms import FormHelper from core.forms import FormHelper
from core.ld import canonicalise from core.ld import canonicalise
from core.signatures import HttpSignature from core.signatures import HttpSignature
from miniq.models import Task
from users.decorators import identity_required from users.decorators import identity_required
from users.models import Domain, Follow, Identity from users.models import Domain, Follow, Identity
from users.shortcuts import by_handle_or_404 from users.shortcuts import by_handle_or_404