Skip to content
15 changes: 6 additions & 9 deletions ami/base/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ProjectMixin:
request: rest_framework.request.Request
kwargs: dict

def get_active_project(self) -> Project:
def get_active_project(self) -> Project | None:
from ami.base.serializers import SingleParamSerializer

project_id = None
Expand All @@ -29,13 +29,10 @@ def get_active_project(self) -> Project:

# If not in URL, try query parameters
if not project_id:
if self.require_project:
project_id = SingleParamSerializer[int].clean(
param_name="project_id",
field=serializers.IntegerField(required=True, min_value=0),
data=self.request.query_params,
)
else:
project_id = self.request.query_params.get("project_id") # No validation
project_id = SingleParamSerializer[int].clean(
param_name="project_id",
field=serializers.IntegerField(required=self.require_project, min_value=0),
data=self.request.query_params,
)

return get_object_or_404(Project, id=project_id) if project_id else None
52 changes: 32 additions & 20 deletions ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.db.models import QuerySet
from guardian.shortcuts import get_perms
from rest_framework import serializers
from rest_framework.fields import Field
from rest_framework.request import Request

from ami.base.fields import DateStringField
Expand Down Expand Up @@ -449,7 +450,34 @@ def get_occurrences(self, obj):
)


class TaxonCoverImageField(Field):
"""
A custom field for retrieving a taxon's cover image URL.

This field handles the logic for determining the appropriate cover image URL:
1. Uses the taxon's cover_image_url if available
2. Falls back to the best_detection_image_path (added by QuerySet annotation)
3. Returns None if no image is available
"""

def __init__(self, **kwargs):
kwargs["source"] = "*" # Use the entire object as the source
kwargs["read_only"] = True
super().__init__(**kwargs)

def to_representation(self, obj):
if obj.cover_image_url:
return obj.cover_image_url
elif hasattr(obj, "best_detection_image_path") and obj.best_detection_image_path:
# This attribute is added by a QuerySet annotation
return get_media_url(obj.best_detection_image_path)
else:
return None


class TaxonNoParentNestedSerializer(DefaultSerializer):
cover_image_url = TaxonCoverImageField()

class Meta:
model = Taxon
fields = [
Expand Down Expand Up @@ -489,6 +517,8 @@ class Meta(TaxonNoParentNestedSerializer.Meta):


class TaxonSearchResultSerializer(TaxonNestedSerializer):
cover_image_url = TaxonCoverImageField()

class Meta:
model = Taxon
fields = [
Expand Down Expand Up @@ -521,7 +551,7 @@ class TaxonListSerializer(DefaultSerializer):
occurrences = serializers.SerializerMethodField()
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")
parent_id = serializers.PrimaryKeyRelatedField(queryset=Taxon.objects.all(), source="parent")
cover_image_url = serializers.SerializerMethodField()
cover_image_url = TaxonCoverImageField()
tags = serializers.SerializerMethodField()

def get_tags(self, obj):
Expand Down Expand Up @@ -565,15 +595,6 @@ def get_occurrences(self, obj):
params=params,
)

def get_cover_image_url(self, obj):
if obj.cover_image_url:
return obj.cover_image_url
elif hasattr(obj, "best_detection_image_path") and obj.best_detection_image_path:
# This attribute is added by an QuerySet annotation
return get_media_url(obj.best_detection_image_path)
else:
return None


class TaxaListSerializer(serializers.ModelSerializer):
taxa = serializers.SerializerMethodField()
Expand Down Expand Up @@ -758,7 +779,7 @@ class TaxonSerializer(DefaultSerializer):
parent = TaxonNoParentNestedSerializer(read_only=True)
parent_id = serializers.PrimaryKeyRelatedField(queryset=Taxon.objects.all(), source="parent", write_only=True)
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")
cover_image_url = serializers.SerializerMethodField()
cover_image_url = TaxonCoverImageField()
tags = serializers.SerializerMethodField()

def get_tags(self, obj):
Expand Down Expand Up @@ -788,15 +809,6 @@ class Meta:
"unknown_species",
]

def get_cover_image_url(self, obj):
if obj.cover_image_url:
return obj.cover_image_url
elif hasattr(obj, "best_detection_image_path") and obj.best_detection_image_path:
# This attribute is added by an QuerySet annotation
return get_media_url(obj.best_detection_image_path)
else:
return None


class CaptureOccurrenceSerializer(DefaultSerializer):
determination = TaxonNoParentNestedSerializer(read_only=True)
Expand Down
18 changes: 13 additions & 5 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,8 @@ class TaxonViewSet(DefaultViewSet, ProjectMixin):
API endpoint that allows taxa to be viewed or edited.
"""

require_project = True # Taxa are always associated with a project

queryset = Taxon.objects.all().defer("notes")
serializer_class = TaxonSerializer
filter_backends = DefaultViewSetMixin.filter_backends + [
Expand Down Expand Up @@ -1278,8 +1280,7 @@ def suggest(self, request):

if query and len(query) >= min_query_length:
taxa = (
Taxon.objects.filter(active=True)
# .select_related("parent")
self.get_queryset()
.filter(models.Q(name__icontains=query) | models.Q(search_names__icontains=query))
.annotate(
# Calculate similarity for the name field
Expand Down Expand Up @@ -1366,11 +1367,16 @@ def get_queryset(self) -> QuerySet:
and add extra data about the occurrences.
Otherwise return all taxa that are active.
"""
qs = super().get_queryset()
qs = super().get_queryset().filter(active=True)
project = self.get_active_project()
qs = self.attach_tags_by_project(qs, project)

if project:
# Filter by project, but also include global taxa
# @TODO IMPORTANT: if taxa belongs to a project, ensure user has permission to view it
qs = qs.filter(models.Q(projects=project) | models.Q(projects__isnull=True))

qs = self.attach_tags_by_project(qs, project)

include_unobserved = True # Show detail views for unobserved taxa instead of 404
# @TODO move to a QuerySet manager
qs = qs.annotate(
Expand Down Expand Up @@ -1512,7 +1518,9 @@ def get_queryset(self):
qs = super().get_queryset()
project = self.get_active_project()
if project:
return qs.filter(projects=project)
# Filter by project, but also include global taxa
# @TODO IMPORTANT: if taxa belongs to a project, ensure user has permission to view it
return qs.filter(models.Q(projects=project) | models.Q(projects__isnull=True))
return qs

serializer_class = TaxaListSerializer
Expand Down
40 changes: 40 additions & 0 deletions ami/main/migrations/0068_allow_taxa_without_project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Generated by Django 4.2.10 on 2025-05-17 12:13

from django.db import migrations, models


def clear_projects_for_existing_taxa(apps, schema_editor):
"""
Clear the projects field for all existing Taxon instances.
Previously the admin required a project for taxa, but this was an error and
the current taxon-project assignments are mostly random.
This migration resets the project taxa to an acurate state.
"""
Taxon = apps.get_model("main", "Taxon")

# Clear projects for all Taxon instances
for taxon in Taxon.objects.all():
taxon.projects.clear()


class Migration(migrations.Migration):
dependencies = [
("main", "0067_tag_taxon_tags"),
]

operations = [
migrations.AlterField(
model_name="taxalist",
name="projects",
field=models.ManyToManyField(blank=True, related_name="taxa_lists", to="main.project"),
),
migrations.AlterField(
model_name="taxon",
name="projects",
field=models.ManyToManyField(blank=True, related_name="taxa", to="main.project"),
),
migrations.RunPython(
clear_projects_for_existing_taxa,
reverse_code=migrations.RunPython.noop,
),
]
4 changes: 2 additions & 2 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2899,7 +2899,7 @@ class Taxon(BaseModel):

notes = models.TextField(blank=True)

projects = models.ManyToManyField("Project", related_name="taxa")
projects = models.ManyToManyField("Project", related_name="taxa", blank=True)
direct_children: models.QuerySet["Taxon"]
occurrences: models.QuerySet[Occurrence]
classifications: models.QuerySet["Classification"]
Expand Down Expand Up @@ -3163,7 +3163,7 @@ class TaxaList(BaseModel):
description = models.TextField(blank=True)

taxa = models.ManyToManyField(Taxon, related_name="lists")
projects = models.ManyToManyField("Project", related_name="taxa_lists")
projects = models.ManyToManyField("Project", related_name="taxa_lists", blank=True)

class Meta:
ordering = ["-created_at"]
Expand Down
4 changes: 2 additions & 2 deletions ami/main/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def no_test_project_species_list(self):
response = self.client.get(
"/api/v2/taxa/",
{
"project": self.project_one.pk,
"project_id": self.project_one.pk,
"rank": TaxonRank.SPECIES.name,
},
)
Expand Down Expand Up @@ -671,7 +671,7 @@ def test_taxon_detail(self):
taxon = Taxon.objects.last()
assert taxon is not None
print("Testing taxon", taxon, taxon.pk)
response = self.client.get(f"/api/v2/taxa/{taxon.pk}/")
response = self.client.get(f"/api/v2/taxa/{taxon.pk}/?project_id={self.project_one.pk}")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["name"], taxon.name)

Expand Down
1 change: 1 addition & 0 deletions ui/src/components/filtering/filters/taxon-filter.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export const TaxonFilter = ({ value, onAdd, onClear }: FilterProps) => {
>
<TaxonSearch
taxon={taxon}
projectId={projectId}
onTaxonChange={(taxon) => {
if (taxon) {
onAdd(taxon.id)
Expand Down
4 changes: 3 additions & 1 deletion ui/src/components/taxon-search/taxon-search.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ import { useTaxonSearch } from './useTaxonSearch'
export const TaxonSearch = ({
taxon,
onTaxonChange,
projectId,
}: {
taxon?: Taxon
onTaxonChange: (taxon?: Taxon) => void
projectId?: string
}) => {
const [searchString, setSearchString] = useState('')
const debouncedSearchString = useDebounce(searchString, 200)
const { data, isLoading } = useTaxonSearch(debouncedSearchString)
const { data, isLoading } = useTaxonSearch(debouncedSearchString, projectId)

const tree = useMemo(() => {
if (!data?.length) {
Expand Down
3 changes: 3 additions & 0 deletions ui/src/components/taxon-search/taxon-select.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ export const TaxonSelect = ({
onTaxonChange,
taxon,
triggerLabel,
projectId,
}: {
isLoading?: boolean
onTaxonChange: (taxon?: Taxon) => void
taxon?: Taxon
triggerLabel: string
projectId?: string
}) => {
const [open, setOpen] = useState(false)

Expand Down Expand Up @@ -43,6 +45,7 @@ export const TaxonSelect = ({
>
<TaxonSearch
taxon={taxon}
projectId={projectId}
onTaxonChange={(taxon) => {
onTaxonChange(taxon)
setOpen(false)
Expand Down
6 changes: 4 additions & 2 deletions ui/src/components/taxon-search/useTaxonSearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ const convertServerResults = (result: ServerTaxon[]): Taxon[] => {
return _.unionWith(taxa, (t1, t2) => t1.id === t2.id)
}

export const useTaxonSearch = (searchString: string) => {
export const useTaxonSearch = (searchString: string, projectId?: string) => {
const [data, setData] = useState<Taxon[]>()
const [isLoading, setIsLoading] = useState<boolean>()
const [error, setError] = useState<Error>()
const fetchUrl = searchString.length
? `${API_URL}/taxa/suggest/?q=${searchString}&limit=${MAX_NUM_RESULTS}`
? `${API_URL}/taxa/suggest/?q=${searchString}&limit=${MAX_NUM_RESULTS}${
projectId ? `&project_id=${projectId}` : ''
}`
: undefined

useEffect(() => {
Expand Down