Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,9 @@ poetry.lock

# Zed editor
pyrightconfig.json

# Files generated by tests
minimal_upload_other_license.json
minimal_upload.json
test_file.txt
toydataset.schema.json
8 changes: 4 additions & 4 deletions easyDataverse/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class DataverseBase(BaseModel):

# ! Overloads
def __setattr__(self, name: str, value: Any) -> None:
if name in self.model_fields:
if name in self.__class__.model_fields:
self._changed.add(name)

return super().__setattr__(name, value)
Expand Down Expand Up @@ -124,7 +124,7 @@ def dataverse_dict(self) -> Dict:
# Get properties and init json_obj
json_obj = {}

for attr, field in self.model_fields.items():
for attr, field in self.__class__.model_fields.items():
if any(name in attr for name in ["add_", "_metadatablock_name"]):
# Only necessary for blind fetch
continue
Expand Down Expand Up @@ -189,7 +189,7 @@ def extract_changed(self) -> List[Dict]:
changed_fields = []

for name in self._changed:
field = self.model_fields[name]
field = self.__class__.model_fields[name]

if self._is_compound(field) and self._is_multiple(field):
value = self._process_multiple_compound(getattr(self, name))
Expand All @@ -206,7 +206,7 @@ def extract_changed(self) -> List[Dict]:
def _add_changed_multiples(self):
"""Checks whether a compound has multiple changed fields"""

for name, field in self.model_fields.items():
for name, field in self.__class__.model_fields.items():
if not self._is_compound(field):
continue
if not self._is_multiple(field):
Expand Down
7 changes: 6 additions & 1 deletion easyDataverse/classgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,10 @@ def generate_add_function(subclass, attribute, name):
"""

def add_fun_template(self, **kwargs):
getattr(self, attribute).append(subclass(**kwargs))
self._changed.add(attribute)
obj = subclass(**kwargs)
obj._changed.update(kwargs.keys())
getattr(self, attribute).append(obj)

signature = create_function_signature(subclass)
new_func = forge.sign(forge.self, *signature)(
Expand Down Expand Up @@ -245,6 +248,8 @@ def create_function_signature(subclass) -> List:
"""
signature = []
for name, dtype in subclass.__annotations__.items():
if name == "_changed":
continue
sig_params = {"name": name, "type": dtype, "interface_name": name}

default = subclass.model_fields[name].default
Expand Down
63 changes: 63 additions & 0 deletions tests/integration/test_dataset_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,69 @@ def test_dataset_update(
"The updated dataset title does not match the expected title."
)

@pytest.mark.integration
def test_dataset_update_with_multiple_fields(
self,
credentials,
):
# Arrange
base_url, api_token = credentials
dataverse = Dataverse(
server_url=base_url,
api_token=api_token,
)

# Create a dataset
dataset = dataverse.create_dataset()
dataset.citation.title = "My dataset"
dataset.citation.subject = ["Other"]
dataset.citation.add_author(name="John Doe")
dataset.citation.add_ds_description(
value="This is a description of the dataset",
date="2024",
)
dataset.citation.add_dataset_contact(
name="John Doe",
email="[email protected]",
)

pid = dataset.upload("Root")

# Act
# Re-fetch the dataset and add other ID
dataset = dataverse.load_dataset(pid)
dataset.citation.add_other_id(agency="DOI", value="10.5072/easy-dataverse")
dataset.update()

# Re-fetch the dataset to verify the update
url = (
f"{base_url}/api/datasets/:persistentId/versions/:draft?persistentId={pid}"
)

response = httpx.get(
url,
headers={"X-Dataverse-key": api_token},
)

response.raise_for_status()
updated_dataset = response.json()
other_id_field = next(
filter(
lambda x: x["typeName"] == "otherId",
updated_dataset["data"]["metadataBlocks"]["citation"]["fields"],
),
None,
)

# Assert
assert other_id_field is not None, "Other ID field should be present"
assert len(other_id_field["value"]) > 0, "Other ID field should have values"
assert any(
item["otherIdAgency"]["value"] == "DOI"
and item["otherIdValue"]["value"] == "10.5072/easy-dataverse"
for item in other_id_field["value"]
), "The DOI other ID should be present in the updated dataset"

@staticmethod
def sort_citation(dataset: Dict):
citation = dataset["datasetVersion"]["metadataBlocks"]["citation"]
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/test_connect.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from enum import Enum
from typing import List, Optional, Union, get_args
from typing import List, Optional, Set, Union, get_args

import pytest
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr

from easyDataverse.classgen import (
camel_to_snake,
Expand Down Expand Up @@ -451,9 +451,11 @@ class TestClass(BaseModel):
name: str
value: int = 42
optional: Optional[str] = None
_changed: Set = PrivateAttr(default_factory=set)

class ParentClass(BaseModel):
to_add_to: List[TestClass] = []
_changed: Set = PrivateAttr(default_factory=set)

# Act
result = generate_add_function(
Expand All @@ -480,6 +482,7 @@ class ParentClass(BaseModel):
"name": str,
"value": int,
"optional": Optional[str],
"_changed": Set,
}

expected_object = TestClass(
Expand Down
Loading