diff --git a/graphene_django_optimizer/types.py b/graphene_django_optimizer/types.py index 604453ba546d27dc6ccddd629c8423bbe8fe8c7a..852182ce19c0f2522a37979dc9407e128a35d9a4 100644 --- a/graphene_django_optimizer/types.py +++ b/graphene_django_optimizer/types.py @@ -15,29 +15,8 @@ class OptimizedDjangoObjectType(DjangoObjectType): and resolver_info.return_type.graphene_type is cls) @classmethod - def get_optimized_node(cls, info, qs, pk): - return query(qs, info).get(pk=pk) - - @classmethod - def maybe_optimize(cls, info, qs, pk): - try: - if cls.can_optimize_resolver(info): - return cls.get_optimized_node(info, qs, pk) - return qs.get(pk=pk) - except cls._meta.model.DoesNotExist: - return None - - @classmethod - def get_node(cls, info, id): - """ - Bear in mind that if you are overriding this method get_node(info, pk), - you should always call maybe_optimize(info, qs, pk) - and never directly call get_optimized_node(info, qs, pk) as it would - result to the node being attempted to be optimized when it is not - supposed to actually get optimized. - - :param info: - :param id: - :return: - """ - return cls.maybe_optimize(info, cls._meta.model.objects, id) + def get_queryset(cls, queryset, info): + queryset = super(OptimizedDjangoObjectType, cls).get_queryset(queryset, info) + if cls.can_optimize_resolver(info): + queryset = query(queryset, info) + return queryset diff --git a/tests/test_types.py b/tests/test_types.py index cb33d23387b9a9e3d159f1bd8a0c0bcc9d096182..42b7b4ec3ec3244b74ee01328c7da3e8b2c04589 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -83,3 +83,31 @@ def test_mutating_should_not_optimize(mocked_optimizer): assert result assert result.pk == 7 assert mocked_optimizer.call_count == 0 + + +@pytest.mark.django_db +@patch('graphene_django_optimizer.types.query', + return_value=SomeOtherItem.objects) +def test_should_optimize_the_queryset(mocked_optimizer): + SomeOtherItem.objects.create(pk=7, name='Hello') + + info = create_resolve_info(schema, ''' + query ItemDetails { + someOtherItems(id: $id) { + id + foo + parent { + id + } + } + } + ''') + + info.return_type = schema.get_type('SomeOtherItemType') + qs = SomeOtherItem.objects.filter(pk=7) + result = SomeOtherItemType.get_queryset(qs, info).get() + + assert result, 'Expected the item to be found and returned' + assert result.pk == 7, 'The item is not the correct one' + + mocked_optimizer.assert_called_once_with(qs, info)