Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ django_settings_module = config.settings.test
# Django migrations should not produce any errors:
ignore_errors = True

[tool:pytest]
DJANGO_SETTINGS_MODULE = config.settings.test

[coverage:run]
include =
users/*
Expand Down
3 changes: 3 additions & 0 deletions xml_manager/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@ class XML_File_PDF_Generation_Error(Exception):

class XML_File_HTML_Generation_Error(Exception):
pass

class SPS_Package_Validation_Error(Exception):
pass
51 changes: 49 additions & 2 deletions xml_manager/tasks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import os
import tempfile

from django.conf import settings
from django.contrib.auth import get_user_model
from django.utils import timezone
from django.utils.translation import gettext_lazy as _

from config import celery_app
Expand All @@ -15,9 +17,12 @@
XML_DOCUMENT_UNKNOWN_ERROR,
)
from tracker.models import XMLDocumentEvent
from xml_manager.forms import SPSPackageValidationForm
from xml_manager.models import (
XMLDocument,
XMLDocumentHTML,
SPSPackageValidation,
SPSPackageValidationStatus,
XMLDocument,
XMLDocumentHTML,
XMLDocumentPDF,
)
from xml_manager import exceptions
Expand Down Expand Up @@ -218,3 +223,45 @@ def task_generate_html_file(self, xml_id, user_id=None, username=None):
save=True,
)
return False


@celery_app.task(bind=True)
def task_validate_sps_package(self, validation_pk):
try:
validation = SPSPackageValidation.objects.get(pk=validation_pk)
except SPSPackageValidation.DoesNotExist:
logging.error(f"SPSPackageValidation pk={validation_pk} does not exist.")
return False

validation.status = SPSPackageValidationStatus.RUNNING
validation.save()

try:
zip_path = validation.package_document.file.path
rows = utils.validate_zip(zip_path)

with tempfile.TemporaryDirectory() as tmpdir:
base_name = os.path.splitext(validation.package_document.title)[0]
csv_path = os.path.join(tmpdir, f"{base_name}.validation.csv")
utils.write_csv(rows, csv_path)

if validation.validation_document:
validation.validation_document.delete()
validation.validation_document = None

validation.validation_document = SPSPackageValidationForm.save_wagtail_document_from_path(
csv_path,
title=f"{base_name}.validation.csv",
)

validation.status = SPSPackageValidationStatus.DONE
validation.validated_at = timezone.now()
validation.error_message = ""

except Exception as e:
logging.error(f"Error during SPS package validation pk={validation_pk}: {e}")
validation.status = SPSPackageValidationStatus.ERROR
validation.error_message = str(e)

validation.save()
return True
189 changes: 188 additions & 1 deletion xml_manager/tests/test_sps_package_validation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import csv
import io
import os
import tempfile
import zipfile
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from django.contrib.auth import get_user_model
from django.contrib.messages.storage.fallback import FallbackStorage
Expand All @@ -15,6 +17,8 @@
from xml_manager.exceptions import SPS_Package_Validation_Error
from xml_manager.forms import SPSPackageValidationForm
from xml_manager.models import SPSPackageValidation, SPSPackageValidationStatus
from xml_manager.tasks import task_validate_sps_package
from xml_manager.utils import FIELDNAMES, validate_zip, write_csv
from xml_manager.views import revalidate_sps_package_pk

User = get_user_model()
Expand Down Expand Up @@ -185,3 +189,186 @@ def test_revalidate_returns_404_for_missing_validation(self):
with self.assertRaises(Http404):
revalidate_sps_package_pk(request, pk=99999)

def test_revalidate_dispatches_task(self):
with patch("xml_manager.views.task_validate_sps_package") as mock_task:
request = _request_for_user(self.user)
revalidate_sps_package_pk(request, pk=self.validation.pk)
mock_task.delay.assert_called_once_with(self.validation.pk)


def _mock_validate_xml_content(items):
return iter([{"group": "test-group", "items": items}])


def _mock_xml_with_pre(mock_xmltree=None):
mock_xml = MagicMock()
mock_xml.xmltree = mock_xmltree or MagicMock()
return mock_xml


class ValidateZipTests(TestCase):
def _patch_packtools(self, items):
patcher_xmlwithpre = patch(
"xml_manager.utils.XMLWithPre.create",
return_value=iter([_mock_xml_with_pre()]),
)
patcher_validator = patch(
"xml_manager.utils.xml_validator.validate_xml_content",
return_value=_mock_validate_xml_content(items),
)
patcher_journal = patch(
"xml_manager.utils._extract_journal_data",
return_value={},
)
return patcher_xmlwithpre, patcher_validator, patcher_journal

def test_returns_list_of_dicts(self):
item = {
"title": "t", "parent": "article", "parent_id": None,
"parent_article_type": "research-article", "item": "article-id",
"sub_item": None, "validation_type": "format",
"response": "ERROR", "expected_value": "doi",
"got_value": None, "advice": "add doi",
}
p1, p2, p3 = self._patch_packtools([item])
with p1, p2, p3:
result = validate_zip("fake.zip")
self.assertIsInstance(result, list)
self.assertEqual(len(result), 1)

def test_result_has_expected_keys(self):
item = {
"title": "t", "parent": "article", "parent_id": None,
"parent_article_type": "research-article", "item": "article-id",
"sub_item": "pub-id-type", "validation_type": "format",
"response": "ERROR", "expected_value": "doi",
"got_value": None, "advice": "add doi",
}
p1, p2, p3 = self._patch_packtools([item])
with p1, p2, p3:
result = validate_zip("fake.zip")
for key in FIELDNAMES:
self.assertIn(key, result[0])

def test_attribute_concatenates_item_and_sub_item(self):
item = {
"title": None, "parent": None, "parent_id": None,
"parent_article_type": None, "item": "foo", "sub_item": "bar",
"validation_type": None, "response": "OK",
"expected_value": None, "got_value": None, "advice": None,
}
p1, p2, p3 = self._patch_packtools([item])
with p1, p2, p3:
result = validate_zip("fake.zip")
self.assertEqual(result[0]["attribute"], "foo/bar")

def test_attribute_omits_empty_sub_item(self):
item = {
"title": None, "parent": None, "parent_id": None,
"parent_article_type": None, "item": "foo", "sub_item": None,
"validation_type": None, "response": "OK",
"expected_value": None, "got_value": None, "advice": None,
}
p1, p2, p3 = self._patch_packtools([item])
with p1, p2, p3:
result = validate_zip("fake.zip")
self.assertEqual(result[0]["attribute"], "foo")

def test_handles_none_items_in_group(self):
p1, p2, p3 = self._patch_packtools([None, None])
with p1, p2, p3:
result = validate_zip("fake.zip")
self.assertEqual(result, [])

def test_handles_none_items_list(self):
p1, p2, p3 = self._patch_packtools(None)
with p1, p2, p3:
result = validate_zip("fake.zip")
self.assertEqual(result, [])


class WriteCsvTests(TestCase):
def test_creates_file(self):
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "out.csv")
write_csv([], path)
self.assertTrue(os.path.exists(path))

def test_has_correct_header(self):
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "out.csv")
write_csv([], path)
with open(path, newline="", encoding="utf-8") as f:
reader = csv.DictReader(f)
self.assertEqual(reader.fieldnames, FIELDNAMES)

def test_returns_output_path(self):
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "out.csv")
returned = write_csv([], path)
self.assertEqual(returned, path)


class TaskValidateSpsPackageTests(TestCase):
def setUp(self):
self.user = User.objects.create_user(
username="staff", password="secret", is_staff=True
)
self.package, self.zip_size = make_package_document()
self.validation = SPSPackageValidation.objects.create(
package_document=self.package,
zip_size_bytes=self.zip_size,
validated_by=self.user,
)

def _run_task(self, rows=None):
with patch("xml_manager.tasks.utils.validate_zip", return_value=rows or []):
task_validate_sps_package.delay(self.validation.pk)

def test_status_transitions_to_done(self):
self._run_task()
self.validation.refresh_from_db()
self.assertEqual(self.validation.status, SPSPackageValidationStatus.DONE)

def test_validation_document_is_created(self):
self._run_task()
self.validation.refresh_from_db()
self.assertIsNotNone(self.validation.validation_document)

def test_validated_at_is_set(self):
self._run_task()
self.validation.refresh_from_db()
self.assertIsNotNone(self.validation.validated_at)

def test_error_message_cleared_on_done(self):
self.validation.error_message = "old error"
self.validation.save()
self._run_task()
self.validation.refresh_from_db()
self.assertEqual(self.validation.error_message, "")

def test_status_transitions_to_error_on_failure(self):
with patch("xml_manager.tasks.utils.validate_zip", side_effect=Exception("boom")):
task_validate_sps_package.delay(self.validation.pk)
self.validation.refresh_from_db()
self.assertEqual(self.validation.status, SPSPackageValidationStatus.ERROR)
self.assertEqual(self.validation.error_message, "boom")

def test_existing_validation_document_replaced(self):
old_doc = Document(title="old.csv")
old_doc.file.save(
"old.csv",
SimpleUploadedFile("old.csv", b"a,b\n"),
save=True,
)
old_doc_pk = old_doc.pk
self.validation.validation_document = old_doc
self.validation.save()

self._run_task()
self.validation.refresh_from_db()

self.assertFalse(Document.objects.filter(pk=old_doc_pk).exists())
self.assertIsNotNone(self.validation.validation_document)
self.assertNotEqual(self.validation.validation_document.pk, old_doc_pk)

73 changes: 73 additions & 0 deletions xml_manager/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import csv
import os

from packtools import data_checker
from packtools.sps.models.article_license import ArticleLicense
from packtools.sps.pid_provider.models.journal_meta import JournalID, Publisher, Title
from packtools.sps.pid_provider.xml_sps_lib import XMLWithPre
from packtools.sps.validation import xml_validator
from packtools.sps.formats.pdf.pipeline import docx
from packtools.sps.formats.pdf.utils import file_utils
from packtools.sps.formats.pdf.pipeline.xml import extract_article_main_language
Expand All @@ -26,6 +31,74 @@ def validate_xml_document(xml_file_path, output_root_dir, params):
return path_csv, path_exceptions


FIELDNAMES = [
"group", "title", "parent", "parent_id", "parent_article_type",
"item", "sub_item", "attribute", "validation_type",
"response", "expected_value", "got_value", "advice",
]


def _extract_journal_data(xmltree):
try:
license_code = None
for lic in ArticleLicense(xmltree).licenses:
code = lic.get("code")
if code:
license_code = code
break
return {
"abbrev_journal_title": Title(xmltree).abbreviated_journal_title,
"publisher_name_list": Publisher(xmltree).publishers_names,
"nlm_journal_title": JournalID(xmltree).nlm_ta,
"license_code": license_code,
}
except Exception:
return {}


def validate_zip(zip_path: str) -> list:
rows = []
for xml_with_pre in XMLWithPre.create(path=zip_path):
xmltree = xml_with_pre.xmltree
rules = {"journal_data": _extract_journal_data(xmltree)}
for group_result in xml_validator.validate_xml_content(xmltree, rules):
group = group_result.get("group", "")
try:
items = list(group_result.get("items") or [])
except Exception:
continue
for result in items:
if not result:
continue
item = result.get("item") or ""
sub_item = result.get("sub_item") or ""
attribute = "/".join(filter(None, [item, sub_item]))
rows.append({
"group": group,
"title": result.get("title"),
"parent": result.get("parent"),
"parent_id": result.get("parent_id"),
"parent_article_type": result.get("parent_article_type"),
"item": item,
"sub_item": sub_item,
"attribute": attribute,
"validation_type": result.get("validation_type"),
"response": result.get("response"),
"expected_value": result.get("expected_value"),
"got_value": result.get("got_value"),
"advice": result.get("advice"),
})
return rows


def write_csv(rows: list, output_csv: str) -> str:
with open(output_csv, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=FIELDNAMES)
writer.writeheader()
writer.writerows(rows)
return output_csv


def generate_pdf_for_xml_document(xml_file_path, output_root_dir, params):
if not os.path.exists(output_root_dir):
os.makedirs(output_root_dir)
Expand Down
Loading
Loading