Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,9 @@ poetry.lock

# Zed editor
pyrightconfig.json

# Files generated by tests
minimal_upload_other_license.json
minimal_upload.json
test_file.txt
toydataset.schema.json
8 changes: 4 additions & 4 deletions easyDataverse/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class DataverseBase(BaseModel):

# ! Overloads
def __setattr__(self, name: str, value: Any) -> None:
if name in self.model_fields:
if name in self.__class__.model_fields:
self._changed.add(name)

return super().__setattr__(name, value)
Expand Down Expand Up @@ -124,7 +124,7 @@ def dataverse_dict(self) -> Dict:
# Get properties and init json_obj
json_obj = {}

for attr, field in self.model_fields.items():
for attr, field in self.__class__.model_fields.items():
if any(name in attr for name in ["add_", "_metadatablock_name"]):
# Only necessary for blind fetch
continue
Expand Down Expand Up @@ -189,7 +189,7 @@ def extract_changed(self) -> List[Dict]:
changed_fields = []

for name in self._changed:
field = self.model_fields[name]
field = self.__class__.model_fields[name]

if self._is_compound(field) and self._is_multiple(field):
value = self._process_multiple_compound(getattr(self, name))
Expand All @@ -206,7 +206,7 @@ def extract_changed(self) -> List[Dict]:
def _add_changed_multiples(self):
"""Checks whether a compound has multiple changed fields"""

for name, field in self.model_fields.items():
for name, field in self.__class__.model_fields.items():
if not self._is_compound(field):
continue
if not self._is_multiple(field):
Expand Down
7 changes: 6 additions & 1 deletion easyDataverse/classgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,10 @@ def generate_add_function(subclass, attribute, name):
"""

def add_fun_template(self, **kwargs):
getattr(self, attribute).append(subclass(**kwargs))
self._changed.add(attribute)
obj = subclass(**kwargs)
obj._changed.update(kwargs.keys())
getattr(self, attribute).append(obj)

signature = create_function_signature(subclass)
new_func = forge.sign(forge.self, *signature)(
Expand Down Expand Up @@ -245,6 +248,8 @@ def create_function_signature(subclass) -> List:
"""
signature = []
for name, dtype in subclass.__annotations__.items():
if name == "_changed":
continue
sig_params = {"name": name, "type": dtype, "interface_name": name}

default = subclass.model_fields[name].default
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "easyDataverse"
version = "0.4.4"
version = "0.4.5"
description = "Lightweight Dataverse interface in Python to upload, download and update datasets found in Dataverse instances."
authors = ["Jan Range <[email protected]>"]
license = "MIT License"
Expand Down
63 changes: 63 additions & 0 deletions tests/integration/test_dataset_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,69 @@ def test_dataset_update(
"The updated dataset title does not match the expected title."
)

@pytest.mark.integration
def test_dataset_update_with_multiple_fields(
self,
credentials,
):
# Arrange
base_url, api_token = credentials
dataverse = Dataverse(
server_url=base_url,
api_token=api_token,
)

# Create a dataset
dataset = dataverse.create_dataset()
dataset.citation.title = "My dataset"
dataset.citation.subject = ["Other"]
dataset.citation.add_author(name="John Doe")
dataset.citation.add_ds_description(
value="This is a description of the dataset",
date="2024",
)
dataset.citation.add_dataset_contact(
name="John Doe",
email="[email protected]",
)

pid = dataset.upload("Root")

# Act
# Re-fetch the dataset and add other ID
dataset = dataverse.load_dataset(pid)
dataset.citation.add_other_id(agency="DOI", value="10.5072/easy-dataverse")
dataset.update()

# Re-fetch the dataset to verify the update
url = (
f"{base_url}/api/datasets/:persistentId/versions/:draft?persistentId={pid}"
)

response = httpx.get(
url,
headers={"X-Dataverse-key": api_token},
)

response.raise_for_status()
updated_dataset = response.json()
other_id_field = next(
filter(
lambda x: x["typeName"] == "otherId",
updated_dataset["data"]["metadataBlocks"]["citation"]["fields"],
),
None,
)

# Assert
assert other_id_field is not None, "Other ID field should be present"
assert len(other_id_field["value"]) > 0, "Other ID field should have values"
assert any(
item["otherIdAgency"]["value"] == "DOI"
and item["otherIdValue"]["value"] == "10.5072/easy-dataverse"
for item in other_id_field["value"]
), "The DOI other ID should be present in the updated dataset"

@staticmethod
def sort_citation(dataset: Dict):
citation = dataset["datasetVersion"]["metadataBlocks"]["citation"]
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/test_connect.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from enum import Enum
from typing import List, Optional, Union, get_args
from typing import List, Optional, Set, Union, get_args

import pytest
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr

from easyDataverse.classgen import (
camel_to_snake,
Expand Down Expand Up @@ -451,9 +451,11 @@ class TestClass(BaseModel):
name: str
value: int = 42
optional: Optional[str] = None
_changed: Set = PrivateAttr(default_factory=set)

class ParentClass(BaseModel):
to_add_to: List[TestClass] = []
_changed: Set = PrivateAttr(default_factory=set)

# Act
result = generate_add_function(
Expand All @@ -480,6 +482,7 @@ class ParentClass(BaseModel):
"name": str,
"value": int,
"optional": Optional[str],
"_changed": Set,
}

expected_object = TestClass(
Expand Down