diff --git a/dev-env-requirements.txt b/dev-env-requirements.txt
index 4f70c88c806d58d2a5933b7f632152683fb3926f..52eed6ff723fc7d525774900a004f528fe576136 100644
--- a/dev-env-requirements.txt
+++ b/dev-env-requirements.txt
@@ -1,6 +1,6 @@
 -r requirements.txt
-graphene==2.1.8
-graphene-django==2.7.1
+graphene==3.0b7
+graphene-django==3.0.0b7
 pytest==4.6.3
 pytest-django==3.5.0
 pytest-cov==2.7.1
diff --git a/graphene_django_optimizer/field.py b/graphene_django_optimizer/field.py
index c6bd31011defcd5a5c17aefcc1b9890a6cbefca5..48165cc92c853b035b12fe18133a69e2c50754bb 100644
--- a/graphene_django_optimizer/field.py
+++ b/graphene_django_optimizer/field.py
@@ -1,3 +1,4 @@
+import types
 from graphene.types.field import Field
 from graphene.types.unmountedtype import UnmountedType
 
@@ -9,12 +10,12 @@ def field(field_type, *args, **kwargs):
         field_type = Field.mounted(field_type)
 
     optimization_hints = OptimizationHints(*args, **kwargs)
-    get_resolver = field_type.get_resolver
+    wrap_resolve = field_type.wrap_resolve
 
-    def get_optimized_resolver(parent_resolver):
-        resolver = get_resolver(parent_resolver)
+    def get_optimized_resolver(self, parent_resolver):
+        resolver = wrap_resolve(parent_resolver)
         resolver.optimization_hints = optimization_hints
         return resolver
 
-    field_type.get_resolver = get_optimized_resolver
+    field_type.wrap_resolve = types.MethodType(get_optimized_resolver, field_type)
     return field_type
diff --git a/graphene_django_optimizer/query.py b/graphene_django_optimizer/query.py
index 057125af0ab0ff0cd59d81992b2cbdede36fecee..e64f041f5e44cdea4989425914ef23fa5f87f7d0 100644
--- a/graphene_django_optimizer/query.py
+++ b/graphene_django_optimizer/query.py
@@ -8,20 +8,20 @@ from graphene import InputObjectType
 from graphene.types.generic import GenericScalar
 from graphene.types.resolver import default_resolver
 from graphene_django import DjangoObjectType
-from graphql import ResolveInfo
-from graphql.execution.base import (
-    get_field_def,
-)
+from graphql import GraphQLResolveInfo, GraphQLSchema
+from graphql.execution.execute import get_field_def
 from graphql.language.ast import (
-    FragmentSpread,
-    InlineFragment,
-    Variable,
+    FragmentSpreadNode,
+    InlineFragmentNode,
+    VariableNode,
 )
 from graphql.type.definition import (
     GraphQLInterfaceType,
     GraphQLUnionType,
 )
 
+from graphql.pyutils import Path
+
 from .utils import is_iterable
 
 
@@ -31,7 +31,7 @@ def query(queryset, info, **options):
 
     Arguments:
         - queryset (Django QuerySet object) - The queryset to be optimized
-        - info (GraphQL ResolveInfo object) - This is passed by the graphene-django resolve methods
+        - info (GraphQL GraphQLResolveInfo object) - This is passed by the graphene-django resolve methods
         - **options - optimization options/settings
             - disable_abort_only (boolean) - in case the objecttype contains any extra fields,
                                              then this will keep the "only" optimization enabled.
@@ -54,7 +54,7 @@ class QueryOptimizer(object):
         field_def = get_field_def(info.schema, info.parent_type, info.field_name)
         store = self._optimize_gql_selections(
             self._get_type(field_def),
-            info.field_asts[0],
+            info.field_nodes[0],
             # info.parent_type,
         )
         return store.optimize_queryset(queryset)
@@ -65,9 +65,16 @@ class QueryOptimizer(object):
             a_type = a_type.of_type
         return a_type
 
+    def _get_graphql_schema(self, schema):
+        if isinstance(schema, GraphQLSchema):
+            return schema
+        else:
+            return schema.graphql_schema
+
     def _get_possible_types(self, graphql_type):
         if isinstance(graphql_type, (GraphQLInterfaceType, GraphQLUnionType)):
-            return self.root_info.schema.get_possible_types(graphql_type)
+            graphql_schema = self._get_graphql_schema(self.root_info.schema)
+            return graphql_schema.get_possible_types(graphql_type)
         else:
             return (graphql_type,)
 
@@ -80,7 +87,8 @@ class QueryOptimizer(object):
 
     def handle_inline_fragment(self, selection, schema, possible_types, store):
         fragment_type_name = selection.type_condition.name.value
-        fragment_type = schema.get_type(fragment_type_name)
+        graphql_schema = self._get_graphql_schema(schema)
+        fragment_type = graphql_schema.get_type(fragment_type_name)
         fragment_possible_types = self._get_possible_types(fragment_type)
         for fragment_possible_type in fragment_possible_types:
             fragment_model = fragment_possible_type.graphene_type._meta.model
@@ -120,14 +128,16 @@ class QueryOptimizer(object):
             return store
         optimized_fields_by_model = {}
         schema = self.root_info.schema
-        graphql_type = schema.get_graphql_type(field_type.graphene_type)
+        graphql_schema = self._get_graphql_schema(schema)
+        graphql_type = graphql_schema.get_type(field_type.name)
+
         possible_types = self._get_possible_types(graphql_type)
         for selection in selection_set.selections:
-            if isinstance(selection, InlineFragment):
+            if isinstance(selection, InlineFragmentNode):
                 self.handle_inline_fragment(selection, schema, possible_types, store)
             else:
                 name = selection.name.value
-                if isinstance(selection, FragmentSpread):
+                if isinstance(selection, FragmentSpreadNode):
                     self.handle_fragment_spread(store, name, field_type)
                 else:
                     for possible_type in possible_types:
@@ -176,7 +186,7 @@ class QueryOptimizer(object):
             store.abort_only_optimization()
 
     def _optimize_field_by_name(self, store, model, selection, field_def):
-        name = self._get_name_from_resolver(field_def.resolver)
+        name = self._get_name_from_resolver(field_def.resolve)
         if not name:
             return False
         model_field = self._get_model_field_from_name(model, name)
@@ -215,7 +225,7 @@ class QueryOptimizer(object):
         return getattr(resolver, "optimization_hints", None)
 
     def _get_value(self, info, value):
-        if isinstance(value, Variable):
+        if isinstance(value, VariableNode):
             var_name = value.name.value
             value = info.variable_values.get(var_name)
             return value
@@ -225,7 +235,7 @@ class QueryOptimizer(object):
             return GenericScalar.parse_literal(value)
 
     def _optimize_field_by_hints(self, store, selection, field_def, parent_type):
-        optimization_hints = self._get_optimization_hints(field_def.resolver)
+        optimization_hints = self._get_optimization_hints(field_def.resolve)
         if not optimization_hints:
             return False
         info = self._create_resolve_info(
@@ -316,17 +326,19 @@ class QueryOptimizer(object):
         )
 
     def _create_resolve_info(self, field_name, field_asts, return_type, parent_type):
-        return ResolveInfo(
+        return GraphQLResolveInfo(
             field_name,
             field_asts,
             return_type,
             parent_type,
+            Path(None, 0, None),
             schema=self.root_info.schema,
             fragments=self.root_info.fragments,
             root_value=self.root_info.root_value,
             operation=self.root_info.operation,
             variable_values=self.root_info.variable_values,
             context=self.root_info.context,
+            is_awaitable=self.root_info.is_awaitable,
         )
 
 
diff --git a/tests/graphql_utils.py b/tests/graphql_utils.py
index 6b061e43ea1f9a06e9945ff01203bb7aaf4a9aad..c6f1a8c00f0c03bc16eba03c8565a6a48c8afb0a 100644
--- a/tests/graphql_utils.py
+++ b/tests/graphql_utils.py
@@ -1,40 +1,38 @@
 from graphql import (
-    ResolveInfo,
+    GraphQLResolveInfo,
     Source,
     Undefined,
     parse,
 )
-from graphql.execution.base import (
+from graphql.execution.execute import (
     ExecutionContext,
-    collect_fields,
     get_field_def,
-    get_operation_root_type,
 )
-from graphql.pyutils.default_ordered_dict import DefaultOrderedDict
+from graphql.utilities import get_operation_root_type
+from collections import defaultdict
+
+from graphql.pyutils import Path
 
 
 def create_execution_context(schema, request_string, variables=None):
     source = Source(request_string, "GraphQL request")
     document_ast = parse(source)
-    return ExecutionContext(
+    return ExecutionContext.build(
         schema,
         document_ast,
         root_value=None,
         context_value=None,
-        variable_values=variables,
+        raw_variable_values=variables,
         operation_name=None,
-        executor=None,
         middleware=None,
-        allow_subscriptions=False,
     )
 
 
 def get_field_asts_from_execution_context(exe_context):
-    fields = collect_fields(
-        exe_context,
+    fields = exe_context.collect_fields(
         type,
         exe_context.operation.selection_set,
-        DefaultOrderedDict(list),
+        defaultdict(list),
         set(),
     )
     # field_asts = next(iter(fields.values()))
@@ -42,7 +40,7 @@ def get_field_asts_from_execution_context(exe_context):
     return field_asts
 
 
-def create_resolve_info(schema, request_string, variables=None):
+def create_resolve_info(schema, request_string, variables=None, return_type=None):
     exe_context = create_execution_context(schema, request_string, variables)
     parent_type = get_operation_root_type(schema, exe_context.operation)
     field_asts = get_field_asts_from_execution_context(exe_context)
@@ -50,24 +48,26 @@ def create_resolve_info(schema, request_string, variables=None):
     field_ast = field_asts[0]
     field_name = field_ast.name.value
 
-    field_def = get_field_def(schema, parent_type, field_name)
-    if not field_def:
-        return Undefined
-    return_type = field_def.type
+    if return_type is None:
+        field_def = get_field_def(schema, parent_type, field_name)
+        if not field_def:
+            return Undefined
+        return_type = field_def.type
 
     # The resolve function's optional third argument is a context value that
     # is provided to every resolve function within an execution. It is commonly
     # used to represent an authenticated user, or request-specific caches.
-    context = exe_context.context_value
-    return ResolveInfo(
+    return GraphQLResolveInfo(
         field_name,
         field_asts,
         return_type,
         parent_type,
-        schema=schema,
-        fragments=exe_context.fragments,
-        root_value=exe_context.root_value,
-        operation=exe_context.operation,
-        variable_values=exe_context.variable_values,
-        context=context,
+        Path(None, 0, None),
+        schema,
+        exe_context.fragments,
+        exe_context.root_value,
+        exe_context.operation,
+        exe_context.variable_values,
+        exe_context.context_value,
+        exe_context.is_awaitable,
     )
diff --git a/tests/schema.py b/tests/schema.py
index 66ffd9a1af25f2dbd14bf49b76282a6da263f529..fc30203a53f33d2c468599d751538d58c1b02c78 100644
--- a/tests/schema.py
+++ b/tests/schema.py
@@ -99,6 +99,7 @@ class BaseItemType(OptimizedDjangoObjectType):
 
     class Meta:
         model = Item
+        fields = "__all__"
 
     @gql_optimizer.resolver_hints(
         model_field="children",
@@ -110,6 +111,8 @@ class BaseItemType(OptimizedDjangoObjectType):
 class ItemNode(BaseItemType):
     class Meta:
         model = Item
+        fields = "__all__"
+
         interfaces = (
             graphene.relay.Node,
             ItemInterface,
@@ -119,16 +122,19 @@ class ItemNode(BaseItemType):
 class SomeOtherItemType(OptimizedDjangoObjectType):
     class Meta:
         model = SomeOtherItem
+        fields = "__all__"
 
 
 class OtherItemType(OptimizedDjangoObjectType):
     class Meta:
         model = OtherItem
+        fields = "__all__"
 
 
 class ItemType(BaseItemType):
     class Meta:
         model = Item
+        fields = "__all__"
         interfaces = (ItemInterface,)
 
 
@@ -144,29 +150,34 @@ class DetailedInterface(graphene.Interface):
 class DetailedItemType(ItemType):
     class Meta:
         model = DetailedItem
+        fields = "__all__"
         interfaces = (ItemInterface, DetailedInterface)
 
 
 class RelatedItemType(ItemType):
     class Meta:
         model = RelatedItem
+        fields = "__all__"
         interfaces = (ItemInterface,)
 
 
 class ExtraDetailedItemType(DetailedItemType):
     class Meta:
         model = ExtraDetailedItem
+        fields = "__all__"
         interfaces = (ItemInterface,)
 
 
 class RelatedOneToManyItemType(OptimizedDjangoObjectType):
     class Meta:
         model = RelatedOneToManyItem
+        fields = "__all__"
 
 
 class UnrelatedModelType(OptimizedDjangoObjectType):
     class Meta:
         model = UnrelatedModel
+        fields = "__all__"
         interfaces = (DetailedInterface,)
 
 
@@ -200,6 +211,21 @@ class Query(graphene.ObjectType):
         return gql_optimizer.query(OtherItemType.objects.all(), info)
 
 
-schema = graphene.Schema(
-    query=Query, types=(UnrelatedModelType,), mutation=DummyItemMutation
-)
+class Schema(graphene.Schema):
+    @property
+    def query_type(self):
+        return self.graphql_schema.get_type("Query")
+
+    @property
+    def mutation_type(self):
+        return self.graphql_schema.get_type("Mutation")
+
+    @property
+    def subscription_type(self):
+        return self.graphql_schema.get_type("Subscription")
+
+    def get_type(self, _type):
+        return self.graphql_schema.get_type(_type)
+
+
+schema = Schema(query=Query, types=(UnrelatedModelType,), mutation=DummyItemMutation)
diff --git a/tests/test_field.py b/tests/test_field.py
index d7a8957bd73f4ad16a417d0588c57a03dedd6f78..71d617e90b52bbe2ad67d554b83535f2afac83a9 100644
--- a/tests/test_field.py
+++ b/tests/test_field.py
@@ -1,13 +1,13 @@
+import pytest
 import graphene_django_optimizer as gql_optimizer
 
 from .graphql_utils import create_resolve_info
-from .models import (
-    Item,
-)
+from .models import Item
 from .schema import schema
 from .test_utils import assert_query_equality
 
 
+@pytest.mark.django_db
 def test_should_optimize_non_django_field_if_it_has_an_optimization_hint_in_the_field():
     info = create_resolve_info(
         schema,
@@ -29,6 +29,7 @@ def test_should_optimize_non_django_field_if_it_has_an_optimization_hint_in_the_
     assert_query_equality(items, optimized_items)
 
 
+@pytest.mark.django_db
 def test_should_optimize_with_only_hint():
     info = create_resolve_info(
         schema,
diff --git a/tests/test_query.py b/tests/test_query.py
index a821ec59ef94dc0d60d3fd94a888c5d0f796c3c5..fa7339d574bd306d2f084321b083bd97ddf4c1f1 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -15,7 +15,7 @@ from .schema import schema
 from .test_utils import assert_query_equality
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_reduce_number_of_queries_by_using_select_related():
     # parent = Item.objects.create(name='foo')
     # Item.objects.create(name='bar', parent=parent)
@@ -39,7 +39,7 @@ def test_should_reduce_number_of_queries_by_using_select_related():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_reduce_number_of_queries_by_using_prefetch_related():
     # parent = Item.objects.create(name='foo')
     # Item.objects.create(name='bar', parent=parent)
@@ -64,7 +64,7 @@ def test_should_reduce_number_of_queries_by_using_prefetch_related():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_optimize_scalar_model_fields():
     # Item.objects.create(name='foo')
     info = create_resolve_info(
@@ -84,7 +84,7 @@ def test_should_optimize_scalar_model_fields():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_optimize_scalar_foreign_key_model_fields():
     # parent = Item.objects.create(name='foo')
     # Item.objects.create(name='bar', parent=parent)
@@ -105,7 +105,7 @@ def test_should_optimize_scalar_foreign_key_model_fields():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_not_try_to_optimize_non_model_fields():
     # Item.objects.create(name='foo')
     info = create_resolve_info(
@@ -125,7 +125,7 @@ def test_should_not_try_to_optimize_non_model_fields():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_not_try_to_optimize_non_field_model_fields():
     # Item.objects.create(name='foo')
     info = create_resolve_info(
@@ -145,6 +145,7 @@ def test_should_not_try_to_optimize_non_field_model_fields():
     assert_query_equality(items, optimized_items)
 
 
+@pytest.mark.django_db
 def test_should_try_to_optimize_non_field_model_fields_when_disabling_abort_only():
     # Item.objects.create(name='foo')
     info = create_resolve_info(
@@ -164,7 +165,7 @@ def test_should_try_to_optimize_non_field_model_fields_when_disabling_abort_only
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_optimize_when_using_fragments():
     # parent = Item.objects.create(name='foo')
     # Item.objects.create(name='bar', parent=parent)
@@ -190,7 +191,7 @@ def test_should_optimize_when_using_fragments():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_prefetch_field_with_camel_case_name():
     # item = Item.objects.create(name='foo')
     # Item.objects.create(name='bar', item=item)
@@ -215,7 +216,7 @@ def test_should_prefetch_field_with_camel_case_name():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_select_nested_related_fields():
     # parent = Item.objects.create(name='foo')
     # parent = Item.objects.create(name='bar', parent=parent)
@@ -243,7 +244,7 @@ def test_should_select_nested_related_fields():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_prefetch_nested_related_fields():
     # parent = Item.objects.create(name='foo')
     # parent = Item.objects.create(name='bar', parent=parent)
@@ -273,7 +274,7 @@ def test_should_prefetch_nested_related_fields():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_prefetch_nested_select_related_field():
     # parent = Item.objects.create(name='foo')
     # item = Item.objects.create(name='foobar')
@@ -304,7 +305,7 @@ def test_should_prefetch_nested_select_related_field():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_select_nested_prefetch_related_field():
     # parent = Item.objects.create(name='foo')
     # Item.objects.create(name='bar', parent=parent)
@@ -333,7 +334,7 @@ def test_should_select_nested_prefetch_related_field():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_select_nested_prefetch_and_select_related_fields():
     # parent = Item.objects.create(name='foo')
     # item = Item.objects.create(name='bar_item')
@@ -368,7 +369,7 @@ def test_should_select_nested_prefetch_and_select_related_fields():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_fetch_fields_of_related_field():
     # parent = Item.objects.create(name='foo')
     # Item.objects.create(name='bar', parent=parent)
@@ -391,7 +392,7 @@ def test_should_fetch_fields_of_related_field():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_fetch_fields_of_prefetched_field():
     # parent = Item.objects.create(name='foo')
     # Item.objects.create(name='bar', parent=parent)
@@ -417,7 +418,7 @@ def test_should_fetch_fields_of_prefetched_field():
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_fetch_child_model_field_for_interface_field():
     # Item.objects.create(name='foo')
     # ExtraDetailedItem.objects.create(name='foo', extra_detail='test')
@@ -443,7 +444,7 @@ def test_should_fetch_child_model_field_for_interface_field():
 
 
 @pytest.mark.skip(reason="will be tested in the future")
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_fetch_field_of_child_model_when_parent_has_no_optimized_field():
     # Item.objects.create(name='foo')
     # DetailedItem.objects.create(name='foo', item_type='test')
@@ -466,6 +467,7 @@ def test_should_fetch_field_of_child_model_when_parent_has_no_optimized_field():
     assert_query_equality(items, optimized_items)
 
 
+@pytest.mark.django_db
 def test_should_fetch_field_inside_interface_fragment():
     info = create_resolve_info(
         schema,
@@ -488,6 +490,7 @@ def test_should_fetch_field_inside_interface_fragment():
     assert_query_equality(items, optimized_items)
 
 
+@pytest.mark.django_db
 def test_should_use_nested_prefetch_related_while_also_selecting_only_required_fields():
     info = create_resolve_info(
         schema,
@@ -565,7 +568,7 @@ def test_should_check_reverse_relations_add_foreign_key():
     assert len(expected_query_capture) == len(optimized_query_capture)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_only_use_the_only_and_not_select_related():
     info = create_resolve_info(
         schema,
diff --git a/tests/test_relay.py b/tests/test_relay.py
index 1e003f4208008c7dc213a717ad021b405b937365..bfd75e8e3c83aac5769dcc82b6baaec27b6be341 100644
--- a/tests/test_relay.py
+++ b/tests/test_relay.py
@@ -3,9 +3,7 @@ import pytest
 import graphene_django_optimizer as gql_optimizer
 
 from .graphql_utils import create_resolve_info
-from .models import (
-    Item,
-)
+from .models import Item
 from .schema import schema
 from .test_utils import assert_query_equality
 
@@ -37,6 +35,7 @@ def test_should_return_valid_result_in_a_relay_query():
     assert result.data["relayItems"]["edges"][0]["node"]["name"] == "foo"
 
 
+@pytest.mark.django_db
 def test_should_reduce_number_of_queries_in_relay_schema_by_using_select_related():
     info = create_resolve_info(
         schema,
@@ -62,6 +61,7 @@ def test_should_reduce_number_of_queries_in_relay_schema_by_using_select_related
     assert_query_equality(items, optimized_items)
 
 
+@pytest.mark.django_db
 def test_should_reduce_number_of_queries_in_relay_schema_by_using_prefetch_related():
     info = create_resolve_info(
         schema,
@@ -88,6 +88,7 @@ def test_should_reduce_number_of_queries_in_relay_schema_by_using_prefetch_relat
     assert_query_equality(items, optimized_items)
 
 
+@pytest.mark.django_db
 def test_should_optimize_query_by_only_requesting_id_field():
     try:
         from django.db.models import DEFERRED  # noqa: F401
diff --git a/tests/test_resolver.py b/tests/test_resolver.py
index 7a063a43d5f214fcdc981b42aaac41c5801069c7..bbeb1fad7f40a848c2968e5804454ff9fc32902e 100644
--- a/tests/test_resolver.py
+++ b/tests/test_resolver.py
@@ -4,14 +4,12 @@ from django.db.models import Prefetch
 import graphene_django_optimizer as gql_optimizer
 
 from .graphql_utils import create_resolve_info
-from .models import (
-    Item,
-)
+from .models import Item
 from .schema import schema
 from .test_utils import assert_query_equality
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_optimize_non_django_field_if_it_has_an_optimization_hint_in_the_resolver():
     # parent = Item.objects.create(name='foo')
     # Item.objects.create(name='bar', parent=parent)
@@ -39,7 +37,7 @@ def test_should_optimize_non_django_field_if_it_has_an_optimization_hint_in_the_
     assert_query_equality(items, optimized_items)
 
 
-# @pytest.mark.django_db
+@pytest.mark.django_db
 def test_should_optimize_with_prefetch_related_as_a_string():
     # parent = Item.objects.create(name='foo')
     # Item.objects.create(name='bar', parent=parent)
@@ -62,6 +60,7 @@ def test_should_optimize_with_prefetch_related_as_a_string():
     assert_query_equality(items, optimized_items)
 
 
+@pytest.mark.django_db
 def test_should_optimize_with_prefetch_related_as_a_function():
     # parent = Item.objects.create(name='foo')
     # Item.objects.create(name='bar', parent=parent)
diff --git a/tests/test_types.py b/tests/test_types.py
index 345efd58ae7b6fcc9a3d57d62411d5f7685f0f4b..706151f15dd30180d16b1ec2a4fc5c3d5e3d2308 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -25,9 +25,9 @@ def test_should_optimize_the_single_node(mocked_optimizer):
             }
         }
     """,
+        return_type=schema.graphql_schema.get_type("SomeOtherItemType"),
     )
 
-    info.return_type = schema.get_type("SomeOtherItemType")
     result = SomeOtherItemType.get_node(info, 7)
 
     assert result, "Expected the item to be found and returned"
@@ -55,9 +55,9 @@ def test_should_return_none_when_node_is_not_resolved(mocked_optimizer):
             }
         }
     """,
+        return_type=schema.graphql_schema.get_type("SomeOtherItemType"),
     )
 
-    info.return_type = schema.get_type("SomeOtherItemType")
     qs = SomeOtherItem.objects
     mocked_optimizer.return_value = qs
 
@@ -84,9 +84,9 @@ def test_mutating_should_not_optimize(mocked_optimizer):
             }
         }
     """,
+        return_type=schema.graphql_schema.get_type("SomeOtherItemType"),
     )
 
-    info.return_type = schema.get_type("SomeOtherItemType")
     result = DummyItemMutation.mutate(info, to_global_id("ItemNode", 7))
     assert result
     assert result.pk == 7
@@ -111,9 +111,9 @@ def test_should_optimize_the_queryset(mocked_optimizer):
             }
         }
     """,
+        return_type=schema.graphql_schema.get_type("SomeOtherItemType"),
     )
 
-    info.return_type = schema.get_type("SomeOtherItemType")
     qs = SomeOtherItem.objects.filter(pk=7)
     result = SomeOtherItemType.get_queryset(qs, info).get()