Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the role filtering parameter to be a tuple #7020

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions lms/services/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,7 @@ def get_assignment_roster(
# If rostering is enabled and we do have the data, use it
query = self._roster_service.get_assignment_roster(
assignment,
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of the changes is this, from two parameters to a tuple.

I might create a type alias when we add exclude_role.

h_userids=h_userids,
)

Expand All @@ -229,8 +228,7 @@ def get_assignment_roster(
roster_last_updated = None
# Always fallback to fetch users that have launched the assignment at some point
query = self._user_service.get_users_for_assignment(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
assignment_id=assignment.id,
h_userids=h_userids,
# For launch data we always add the "active" column as true for compatibility with the roster query.
Expand All @@ -249,16 +247,14 @@ def get_course_roster(
# If rostering is enabled and we do have the data, use it
query = self._roster_service.get_course_roster(
lms_course,
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
h_userids=h_userids,
)

else:
# Always fallback to fetch users that have launched the assignment at some point
query = self._user_service.get_users_for_course(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
lms_course=lms_course,
h_userids=h_userids,
# For launch data we always add the "active" column as true for compatibility with the roster query.
Expand All @@ -280,16 +276,14 @@ def get_segments_roster(
# If rostering is enabled and we do have the data, use it
query = self._roster_service.get_segments_roster(
segments,
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
h_userids=h_userids,
)

else:
# Always fallback to fetch users that have launched the assignment at some point
query = self._user_service.get_users_for_segments(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
segment_ids=[segment.id for segment in segments],
h_userids=h_userids,
# For launch data we always add the "active" column as true for compatibility with the roster query.
Expand Down
21 changes: 9 additions & 12 deletions lms/services/roster.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def course_roster_last_updated(self, course: LMSCourse) -> datetime | None:
def get_assignment_roster(
self,
assignment: Assignment,
role_scope: RoleScope | None = None,
role_type: RoleType | None = None,
include_role: tuple[RoleScope | None, RoleType | None] = (None, None),
h_userids: list[str] | None = None,
) -> Select[tuple[LMSUser, bool]]:
"""Get the roster information for a course from our DB."""
Expand All @@ -95,13 +94,12 @@ def get_assignment_roster(
.where(AssignmentRoster.assignment_id == assignment.id)
).distinct()

return self._get_roster(roster_query, role_scope, role_type, h_userids)
return self._get_roster(roster_query, include_role, h_userids)

def get_segments_roster(
self,
segments: list[LMSSegment],
role_scope: RoleScope | None = None,
role_type: RoleType | None = None,
include_role: tuple[RoleScope | None, RoleType | None] = (None, None),
h_userids: list[str] | None = None,
) -> Select[tuple[LMSUser, bool]]:
"""Get the roster information for a segment from our DB."""
Expand All @@ -113,13 +111,12 @@ def get_segments_roster(
.where(LMSSegmentRoster.lms_segment_id.in_([s.id for s in segments]))
).distinct()

return self._get_roster(roster_query, role_scope, role_type, h_userids)
return self._get_roster(roster_query, include_role, h_userids)

def get_course_roster(
self,
lms_course: LMSCourse,
role_scope: RoleScope | None = None,
role_type: RoleType | None = None,
include_role: tuple[RoleScope | None, RoleType | None] = (None, None),
h_userids: list[str] | None = None,
) -> Select[tuple[LMSUser, bool]]:
"""Get the roster information for a course from our DB."""
Expand All @@ -130,19 +127,19 @@ def get_course_roster(
.where(CourseRoster.lms_course_id == lms_course.id)
).distinct()

return self._get_roster(roster_query, role_scope, role_type, h_userids)
return self._get_roster(roster_query, include_role, h_userids)

def _get_roster(
self,
roster_query,
role_scope: RoleScope | None = None,
role_type: RoleType | None = None,
include_role: tuple[RoleScope | None, RoleType | None] = (None, None),
h_userids: list[str] | None = None,
) -> Select[tuple[LMSUser, bool]]:
"""Filter a roster query by role and h_userids.

Helper function for the get_*_roster methods.
"""
role_scope, role_type = include_role
if role_scope:
roster_query = roster_query.where(LTIRole.scope == role_scope)

Expand Down Expand Up @@ -389,7 +386,7 @@ def fetch_canvas_sections_roster(self, lms_course: LMSCourse) -> None:
Sections are different than other rosters:
- We fetch them via the proprietary Canvas API, not the LTI Names and Roles endpoint.

- Due to the return value of that API we don't fetch rosters for indivual sections,
- Due to the return value of that API we don't fetch rosters for individual sections,
but for all sections of one course at once

- The return value of the API doesn't include enough information to create unseen users
Expand Down
60 changes: 29 additions & 31 deletions lms/services/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,7 @@ def _user_search_query(self, application_instance_id, user_id) -> Select:

def get_users_for_assignment(
self,
role_scope: RoleScope,
role_type: RoleType,
include_role: tuple[RoleScope, RoleType],
assignment_id: int,
h_userids: list[str] | None = None,
) -> Select[tuple[LMSUser]]:
Expand All @@ -175,10 +174,8 @@ def get_users_for_assignment(
)
.where(
LMSUserAssignmentMembership.assignment_id == assignment_id,
LMSUserAssignmentMembership.lti_role_id.in_(
select(LTIRole.id).where(
LTIRole.scope == role_scope, LTIRole.type == role_type
)
self._filter_membership_by_role_clause(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding a small helper function here to avoid a bit duplication.

Again, this will make more sense once we add exclude_role, the helper function will do the filtering for both.

LMSUserAssignmentMembership, include_role
),
)
)
Expand All @@ -189,8 +186,7 @@ def get_users_for_assignment(

def get_users_for_course(
self,
role_scope: RoleScope,
role_type: RoleType,
include_role: tuple[RoleScope, RoleType],
lms_course: LMSCourse,
h_userids: list[str] | None = None,
) -> Select[tuple[LMSUser]]:
Expand All @@ -205,10 +201,8 @@ def get_users_for_course(
.join(LMSCourse, LMSCourse.id == LMSCourseMembership.lms_course_id)
.where(
LMSCourseMembership.lms_course_id == lms_course.id,
LMSCourseMembership.lti_role_id.in_(
select(LTIRole.id).where(
LTIRole.scope == role_scope, LTIRole.type == role_type
)
self._filter_membership_by_role_clause(
LMSCourseMembership, include_role
),
)
)
Expand All @@ -219,14 +213,13 @@ def get_users_for_course(

def get_users_for_segments(
self,
role_scope: RoleScope,
role_type: RoleType,
include_role: tuple[RoleScope, RoleType],
segment_ids: list[int],
h_userids: list[str] | None = None,
) -> Select[tuple[LMSUser]]:
"""Get the users that belong to a list of segment.

This method doesn't use roste data, just launches.
This method doesn't use roster data, just launches.
"""
query = (
select(LMSUser)
Expand All @@ -237,10 +230,8 @@ def get_users_for_segments(
)
.where(
LMSSegmentMembership.lms_segment_id.in_(segment_ids),
LMSSegmentMembership.lti_role_id.in_(
select(LTIRole.id).where(
LTIRole.scope == role_scope, LTIRole.type == role_type
)
self._filter_membership_by_role_clause(
LMSSegmentMembership, include_role
),
)
)
Expand All @@ -249,10 +240,9 @@ def get_users_for_segments(

return query.order_by(LMSUser.display_name, LMSUser.id)

def get_users_for_organization( # noqa: PLR0913
def get_users_for_organization(
self,
role_scope: RoleScope,
role_type: RoleType,
include_role: tuple[RoleScope, RoleType],
course_ids: list[int] | None = None,
instructor_h_userid: str | None = None,
admin_organization_ids: list[int] | None = None,
Expand All @@ -273,11 +263,9 @@ def get_users_for_organization( # noqa: PLR0913
)
.join(candidate_courses, candidate_courses.c[0] == Grouping.id)
.where(
LMSCourseMembership.lti_role_id.in_(
select(LTIRole.id).where(
LTIRole.scope == role_scope, LTIRole.type == role_type
)
)
self._filter_membership_by_role_clause(
LMSCourseMembership, include_role
),
)
)

Expand All @@ -288,8 +276,7 @@ def get_users_for_organization( # noqa: PLR0913

def get_users( # noqa: PLR0913
self,
role_scope: RoleScope,
role_type: RoleType,
include_role: tuple[RoleScope, RoleType],
instructor_h_userid: str | None = None,
admin_organization_ids: list[int] | None = None,
course_ids: list[int] | None = None,
Expand All @@ -310,8 +297,7 @@ def get_users( # noqa: PLR0913
:param segment_authority_provided_ids: return only users that belong these segments.
"""
query = self.get_users_for_organization(
role_scope=role_scope,
role_type=role_type,
include_role=include_role,
instructor_h_userid=instructor_h_userid,
admin_organization_ids=admin_organization_ids,
h_userids=h_userids,
Expand Down Expand Up @@ -346,6 +332,18 @@ def get_users( # noqa: PLR0913

return query.order_by(LMSUser.display_name, LMSUser.id)

def _filter_membership_by_role_clause(
self,
MembershipModel, # noqa: N803
include_role: tuple[RoleScope, RoleType],
):
role_scope, role_type = include_role
return MembershipModel.lti_role_id.in_(
select(LTIRole.id).where(
LTIRole.scope == role_scope, LTIRole.type == role_type
)
)


def factory(_context, request):
"""Service factory for the UserService."""
Expand Down
6 changes: 2 additions & 4 deletions lms/views/dashboard/api/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,7 @@ def _students_query(
# Full organization fetch
if not course_ids and not assignment_ids and not segment_authority_provided_ids:
return None, self.user_service.get_users_for_organization(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
h_userids=h_userids,
# Users the current user has access to see
instructor_h_userid=self.request.user.h_userid
Expand All @@ -276,8 +275,7 @@ def _students_query(
).add_columns(true())

return None, self.user_service.get_users(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
course_ids=self.request.parsed_params.get("course_ids"),
assignment_ids=assignment_ids,
# Users the current user has access to see
Expand Down
18 changes: 6 additions & 12 deletions tests/unit/lms/services/dashboard_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,7 @@ def test_get_assignment_roster(

if not roster_available:
user_service.get_users_for_assignment.assert_called_once_with(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
assignment_id=assignment.id,
h_userids=sentinel.h_userids,
)
Expand All @@ -298,8 +297,7 @@ def test_get_assignment_roster(
else:
roster_service.get_assignment_roster.assert_called_once_with(
assignment,
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
h_userids=sentinel.h_userids,
)
assert (
Expand Down Expand Up @@ -330,8 +328,7 @@ def test_get_course_roster(

if not roster_available:
user_service.get_users_for_course.assert_called_once_with(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
lms_course=lms_course,
h_userids=sentinel.h_userids,
)
Expand All @@ -343,8 +340,7 @@ def test_get_course_roster(
else:
roster_service.get_course_roster.assert_called_once_with(
lms_course,
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
h_userids=sentinel.h_userids,
)
assert (
Expand All @@ -367,8 +363,7 @@ def test_get_segment_roster(

if not roster_available:
user_service.get_users_for_segments.assert_called_once_with(
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
segment_ids=[segment.id],
h_userids=sentinel.h_userids,
)
Expand All @@ -380,8 +375,7 @@ def test_get_segment_roster(
else:
roster_service.get_segments_roster.assert_called_once_with(
[segment],
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
include_role=(RoleScope.COURSE, RoleType.LEARNER),
h_userids=sentinel.h_userids,
)
assert (
Expand Down
18 changes: 12 additions & 6 deletions tests/unit/lms/services/roster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ def test_get_course_roster(
result = db_session.execute(
svc.get_course_roster(
lms_course,
role_scope=lti_role.scope if with_role_scope else None,
role_type=lti_role.type if with_role_type else None,
include_role=(
lti_role.scope if with_role_scope else None,
lti_role.type if with_role_type else None,
),
h_userids=[lms_user.h_userid, inactive_lms_user.h_userid]
if with_h_userids
else None,
Expand Down Expand Up @@ -164,8 +166,10 @@ def test_get_assignment_roster(
result = db_session.execute(
svc.get_assignment_roster(
assignment,
role_scope=lti_role.scope if with_role_scope else None,
role_type=lti_role.type if with_role_type else None,
include_role=(
lti_role.scope if with_role_scope else None,
lti_role.type if with_role_type else None,
),
h_userids=[lms_user.h_userid, inactive_lms_user.h_userid]
if with_h_userids
else None,
Expand Down Expand Up @@ -207,8 +211,10 @@ def test_get_segment_roster(
result = db_session.execute(
svc.get_segments_roster(
[lms_segment],
role_scope=lti_role.scope if with_role_scope else None,
role_type=lti_role.type if with_role_type else None,
include_role=(
lti_role.scope if with_role_scope else None,
lti_role.type if with_role_type else None,
),
h_userids=[lms_user.h_userid, inactive_lms_user.h_userid]
if with_h_userids
else None,
Expand Down
Loading