diff --git a/.gitignore b/.gitignore index 0d2e0d2..564da37 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,9 @@ poetry.lock # Zed editor pyrightconfig.json + +# Files generated by tests +minimal_upload_other_license.json +minimal_upload.json +test_file.txt +toydataset.schema.json \ No newline at end of file diff --git a/easyDataverse/base.py b/easyDataverse/base.py index 0c67ce9..58237e5 100644 --- a/easyDataverse/base.py +++ b/easyDataverse/base.py @@ -28,7 +28,7 @@ class DataverseBase(BaseModel): # ! Overloads def __setattr__(self, name: str, value: Any) -> None: - if name in self.model_fields: + if name in self.__class__.model_fields: self._changed.add(name) return super().__setattr__(name, value) @@ -124,7 +124,7 @@ def dataverse_dict(self) -> Dict: # Get properties and init json_obj json_obj = {} - for attr, field in self.model_fields.items(): + for attr, field in self.__class__.model_fields.items(): if any(name in attr for name in ["add_", "_metadatablock_name"]): # Only necessary for blind fetch continue @@ -189,7 +189,7 @@ def extract_changed(self) -> List[Dict]: changed_fields = [] for name in self._changed: - field = self.model_fields[name] + field = self.__class__.model_fields[name] if self._is_compound(field) and self._is_multiple(field): value = self._process_multiple_compound(getattr(self, name)) @@ -206,7 +206,7 @@ def extract_changed(self) -> List[Dict]: def _add_changed_multiples(self): """Checks whether a compound has multiple changed fields""" - for name, field in self.model_fields.items(): + for name, field in self.__class__.model_fields.items(): if not self._is_compound(field): continue if not self._is_multiple(field): diff --git a/easyDataverse/classgen.py b/easyDataverse/classgen.py index c25d40f..2c1644f 100644 --- a/easyDataverse/classgen.py +++ b/easyDataverse/classgen.py @@ -214,7 +214,10 @@ def generate_add_function(subclass, attribute, name): """ def add_fun_template(self, **kwargs): - getattr(self, attribute).append(subclass(**kwargs)) + self._changed.add(attribute) + obj = subclass(**kwargs) + obj._changed.update(kwargs.keys()) + getattr(self, attribute).append(obj) signature = create_function_signature(subclass) new_func = forge.sign(forge.self, *signature)( @@ -245,6 +248,8 @@ def create_function_signature(subclass) -> List: """ signature = [] for name, dtype in subclass.__annotations__.items(): + if name == "_changed": + continue sig_params = {"name": name, "type": dtype, "interface_name": name} default = subclass.model_fields[name].default diff --git a/pyproject.toml b/pyproject.toml index f95400c..340b5b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "easyDataverse" -version = "0.4.4" +version = "0.4.5" description = "Lightweight Dataverse interface in Python to upload, download and update datasets found in Dataverse instances." authors = ["Jan Range "] license = "MIT License" diff --git a/tests/integration/test_dataset_update.py b/tests/integration/test_dataset_update.py index 4bd6952..b7ec291 100644 --- a/tests/integration/test_dataset_update.py +++ b/tests/integration/test_dataset_update.py @@ -63,6 +63,69 @@ def test_dataset_update( "The updated dataset title does not match the expected title." ) + @pytest.mark.integration + def test_dataset_update_with_multiple_fields( + self, + credentials, + ): + # Arrange + base_url, api_token = credentials + dataverse = Dataverse( + server_url=base_url, + api_token=api_token, + ) + + # Create a dataset + dataset = dataverse.create_dataset() + dataset.citation.title = "My dataset" + dataset.citation.subject = ["Other"] + dataset.citation.add_author(name="John Doe") + dataset.citation.add_ds_description( + value="This is a description of the dataset", + date="2024", + ) + dataset.citation.add_dataset_contact( + name="John Doe", + email="john@doe.com", + ) + + pid = dataset.upload("Root") + + # Act + # Re-fetch the dataset and add other ID + dataset = dataverse.load_dataset(pid) + dataset.citation.add_other_id(agency="DOI", value="10.5072/easy-dataverse") + dataset.update() + + # Re-fetch the dataset to verify the update + url = ( + f"{base_url}/api/datasets/:persistentId/versions/:draft?persistentId={pid}" + ) + + response = httpx.get( + url, + headers={"X-Dataverse-key": api_token}, + ) + + response.raise_for_status() + updated_dataset = response.json() + other_id_field = next( + filter( + lambda x: x["typeName"] == "otherId", + updated_dataset["data"]["metadataBlocks"]["citation"]["fields"], + ), + None, + ) + + # Assert + assert other_id_field is not None, "Other ID field should be present" + assert len(other_id_field["value"]) > 0, "Other ID field should have values" + assert any( + item["otherIdAgency"]["value"] == "DOI" + and item["otherIdValue"]["value"] == "10.5072/easy-dataverse" + for item in other_id_field["value"] + ), "The DOI other ID should be present in the updated dataset" + @staticmethod def sort_citation(dataset: Dict): citation = dataset["datasetVersion"]["metadataBlocks"]["citation"] diff --git a/tests/unit/test_connect.py b/tests/unit/test_connect.py index 8331900..25ef6cb 100644 --- a/tests/unit/test_connect.py +++ b/tests/unit/test_connect.py @@ -1,8 +1,8 @@ from enum import Enum -from typing import List, Optional, Union, get_args +from typing import List, Optional, Set, Union, get_args import pytest -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr from easyDataverse.classgen import ( camel_to_snake, @@ -451,9 +451,11 @@ class TestClass(BaseModel): name: str value: int = 42 optional: Optional[str] = None + _changed: Set = PrivateAttr(default_factory=set) class ParentClass(BaseModel): to_add_to: List[TestClass] = [] + _changed: Set = PrivateAttr(default_factory=set) # Act result = generate_add_function( @@ -480,6 +482,7 @@ class ParentClass(BaseModel): "name": str, "value": int, "optional": Optional[str], + "_changed": Set, } expected_object = TestClass(