Merge branch 'main' into mypy-utils

This commit is contained in:
Joeri de Ruiter 2023-08-22 11:40:48 +02:00 committed by GitHub
commit 567c103e59
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 62 additions and 28 deletions

View file

@ -1,4 +1,6 @@
""" handle reading a csv from calibre """ """ handle reading a csv from calibre """
from typing import Any, Optional
from bookwyrm.models import Shelf from bookwyrm.models import Shelf
from . import Importer from . import Importer
@ -9,7 +11,7 @@ class CalibreImporter(Importer):
service = "Calibre" service = "Calibre"
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
# Add timestamp to row_mappings_guesses for date_added to avoid # Add timestamp to row_mappings_guesses for date_added to avoid
# integrity error # integrity error
row_mappings_guesses = [] row_mappings_guesses = []
@ -23,6 +25,6 @@ class CalibreImporter(Importer):
self.row_mappings_guesses = row_mappings_guesses self.row_mappings_guesses = row_mappings_guesses
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def get_shelf(self, normalized_row): def get_shelf(self, normalized_row: dict[str, Optional[str]]) -> Optional[str]:
# Calibre export does not indicate which shelf to use. Use a default one for now # Calibre export does not indicate which shelf to use. Use a default one for now
return Shelf.TO_READ return Shelf.TO_READ

View file

@ -1,8 +1,10 @@
""" handle reading a csv from an external service, defaults are from Goodreads """ """ handle reading a csv from an external service, defaults are from Goodreads """
import csv import csv
from datetime import timedelta from datetime import timedelta
from typing import Iterable, Optional
from django.utils import timezone from django.utils import timezone
from bookwyrm.models import ImportJob, ImportItem, SiteSettings from bookwyrm.models import ImportJob, ImportItem, SiteSettings, User
class Importer: class Importer:
@ -35,19 +37,26 @@ class Importer:
} }
# pylint: disable=too-many-locals # pylint: disable=too-many-locals
def create_job(self, user, csv_file, include_reviews, privacy): def create_job(
self, user: User, csv_file: Iterable[str], include_reviews: bool, privacy: str
) -> ImportJob:
"""check over a csv and creates a database entry for the job""" """check over a csv and creates a database entry for the job"""
csv_reader = csv.DictReader(csv_file, delimiter=self.delimiter) csv_reader = csv.DictReader(csv_file, delimiter=self.delimiter)
rows = list(csv_reader) rows = list(csv_reader)
if len(rows) < 1: if len(rows) < 1:
raise ValueError("CSV file is empty") raise ValueError("CSV file is empty")
rows = enumerate(rows)
mappings = (
self.create_row_mappings(list(fieldnames))
if (fieldnames := csv_reader.fieldnames)
else {}
)
job = ImportJob.objects.create( job = ImportJob.objects.create(
user=user, user=user,
include_reviews=include_reviews, include_reviews=include_reviews,
privacy=privacy, privacy=privacy,
mappings=self.create_row_mappings(csv_reader.fieldnames), mappings=mappings,
source=self.service, source=self.service,
) )
@ -55,16 +64,20 @@ class Importer:
if enforce_limit and allowed_imports <= 0: if enforce_limit and allowed_imports <= 0:
job.complete_job() job.complete_job()
return job return job
for index, entry in rows: for index, entry in enumerate(rows):
if enforce_limit and index >= allowed_imports: if enforce_limit and index >= allowed_imports:
break break
self.create_item(job, index, entry) self.create_item(job, index, entry)
return job return job
def update_legacy_job(self, job): def update_legacy_job(self, job: ImportJob) -> None:
"""patch up a job that was in the old format""" """patch up a job that was in the old format"""
items = job.items items = job.items
headers = list(items.first().data.keys()) first_item = items.first()
if first_item is None:
return
headers = list(first_item.data.keys())
job.mappings = self.create_row_mappings(headers) job.mappings = self.create_row_mappings(headers)
job.updated_date = timezone.now() job.updated_date = timezone.now()
job.save() job.save()
@ -75,24 +88,24 @@ class Importer:
item.normalized_data = normalized item.normalized_data = normalized
item.save() item.save()
def create_row_mappings(self, headers): def create_row_mappings(self, headers: list[str]) -> dict[str, Optional[str]]:
"""guess what the headers mean""" """guess what the headers mean"""
mappings = {} mappings = {}
for (key, guesses) in self.row_mappings_guesses: for (key, guesses) in self.row_mappings_guesses:
value = [h for h in headers if h.lower() in guesses] values = [h for h in headers if h.lower() in guesses]
value = value[0] if len(value) else None value = values[0] if len(values) else None
if value: if value:
headers.remove(value) headers.remove(value)
mappings[key] = value mappings[key] = value
return mappings return mappings
def create_item(self, job, index, data): def create_item(self, job: ImportJob, index: int, data: dict[str, str]) -> None:
"""creates and saves an import item""" """creates and saves an import item"""
normalized = self.normalize_row(data, job.mappings) normalized = self.normalize_row(data, job.mappings)
normalized["shelf"] = self.get_shelf(normalized) normalized["shelf"] = self.get_shelf(normalized)
ImportItem(job=job, index=index, data=data, normalized_data=normalized).save() ImportItem(job=job, index=index, data=data, normalized_data=normalized).save()
def get_shelf(self, normalized_row): def get_shelf(self, normalized_row: dict[str, Optional[str]]) -> Optional[str]:
"""determine which shelf to use""" """determine which shelf to use"""
shelf_name = normalized_row.get("shelf") shelf_name = normalized_row.get("shelf")
if not shelf_name: if not shelf_name:
@ -103,11 +116,15 @@ class Importer:
] ]
return shelf[0] if shelf else None return shelf[0] if shelf else None
def normalize_row(self, entry, mappings): # pylint: disable=no-self-use # pylint: disable=no-self-use
def normalize_row(
self, entry: dict[str, str], mappings: dict[str, Optional[str]]
) -> dict[str, Optional[str]]:
"""use the dataclass to create the formatted row of data""" """use the dataclass to create the formatted row of data"""
return {k: entry.get(v) for k, v in mappings.items()} return {k: entry.get(v) if v else None for k, v in mappings.items()}
def get_import_limit(self, user): # pylint: disable=no-self-use # pylint: disable=no-self-use
def get_import_limit(self, user: User) -> tuple[int, int]:
"""check if import limit is set and return how many imports are left""" """check if import limit is set and return how many imports are left"""
site_settings = SiteSettings.objects.get() site_settings = SiteSettings.objects.get()
import_size_limit = site_settings.import_size_limit import_size_limit = site_settings.import_size_limit
@ -125,7 +142,9 @@ class Importer:
allowed_imports = import_size_limit - imported_books allowed_imports = import_size_limit - imported_books
return enforce_limit, allowed_imports return enforce_limit, allowed_imports
def create_retry_job(self, user, original_job, items): def create_retry_job(
self, user: User, original_job: ImportJob, items: list[ImportItem]
) -> ImportJob:
"""retry items that didn't import""" """retry items that didn't import"""
job = ImportJob.objects.create( job = ImportJob.objects.create(
user=user, user=user,

View file

@ -1,11 +1,16 @@
""" handle reading a tsv from librarything """ """ handle reading a tsv from librarything """
import re import re
from typing import Optional
from bookwyrm.models import Shelf from bookwyrm.models import Shelf
from . import Importer from . import Importer
def _remove_brackets(value: Optional[str]) -> Optional[str]:
return re.sub(r"\[|\]", "", value) if value else None
class LibrarythingImporter(Importer): class LibrarythingImporter(Importer):
"""csv downloads from librarything""" """csv downloads from librarything"""
@ -13,16 +18,19 @@ class LibrarythingImporter(Importer):
delimiter = "\t" delimiter = "\t"
encoding = "ISO-8859-1" encoding = "ISO-8859-1"
def normalize_row(self, entry, mappings): # pylint: disable=no-self-use def normalize_row(
self, entry: dict[str, str], mappings: dict[str, Optional[str]]
) -> dict[str, Optional[str]]: # pylint: disable=no-self-use
"""use the dataclass to create the formatted row of data""" """use the dataclass to create the formatted row of data"""
remove_brackets = lambda v: re.sub(r"\[|\]", "", v) if v else None normalized = {
normalized = {k: remove_brackets(entry.get(v)) for k, v in mappings.items()} k: _remove_brackets(entry.get(v) if v else None)
isbn_13 = normalized.get("isbn_13") for k, v in mappings.items()
isbn_13 = isbn_13.split(", ") if isbn_13 else [] }
isbn_13 = value.split(", ") if (value := normalized.get("isbn_13")) else []
normalized["isbn_13"] = isbn_13[1] if len(isbn_13) > 1 else None normalized["isbn_13"] = isbn_13[1] if len(isbn_13) > 1 else None
return normalized return normalized
def get_shelf(self, normalized_row): def get_shelf(self, normalized_row: dict[str, Optional[str]]) -> Optional[str]:
if normalized_row["date_finished"]: if normalized_row["date_finished"]:
return Shelf.READ_FINISHED return Shelf.READ_FINISHED
if normalized_row["date_started"]: if normalized_row["date_started"]:

View file

@ -1,4 +1,6 @@
""" handle reading a csv from openlibrary""" """ handle reading a csv from openlibrary"""
from typing import Any
from . import Importer from . import Importer
@ -7,7 +9,7 @@ class OpenLibraryImporter(Importer):
service = "OpenLibrary" service = "OpenLibrary"
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
self.row_mappings_guesses.append(("openlibrary_key", ["edition id"])) self.row_mappings_guesses.append(("openlibrary_key", ["edition id"]))
self.row_mappings_guesses.append(("openlibrary_work_key", ["work id"])) self.row_mappings_guesses.append(("openlibrary_work_key", ["work id"]))
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View file

@ -54,10 +54,10 @@ ImportStatuses = [
class ImportJob(models.Model): class ImportJob(models.Model):
"""entry for a specific request for book data import""" """entry for a specific request for book data import"""
user = models.ForeignKey(User, on_delete=models.CASCADE) user: User = models.ForeignKey(User, on_delete=models.CASCADE)
created_date = models.DateTimeField(default=timezone.now) created_date = models.DateTimeField(default=timezone.now)
updated_date = models.DateTimeField(default=timezone.now) updated_date = models.DateTimeField(default=timezone.now)
include_reviews = models.BooleanField(default=True) include_reviews: bool = models.BooleanField(default=True)
mappings = models.JSONField() mappings = models.JSONField()
source = models.CharField(max_length=100) source = models.CharField(max_length=100)
privacy = models.CharField(max_length=255, default="public", choices=PrivacyLevels) privacy = models.CharField(max_length=255, default="public", choices=PrivacyLevels)
@ -76,7 +76,7 @@ class ImportJob(models.Model):
self.save(update_fields=["task_id"]) self.save(update_fields=["task_id"])
def complete_job(self): def complete_job(self) -> None:
"""Report that the job has completed""" """Report that the job has completed"""
self.status = "complete" self.status = "complete"
self.complete = True self.complete = True

View file

@ -16,6 +16,9 @@ ignore_errors = False
[mypy-bookwyrm.utils.*] [mypy-bookwyrm.utils.*]
ignore_errors = False ignore_errors = False
[mypy-bookwyrm.importers.*]
ignore_errors = False
[mypy-celerywyrm.*] [mypy-celerywyrm.*]
ignore_errors = False ignore_errors = False