Problem utilising KerasLayer() for loading Pretrained Models | TensorFlow

Problem Description

I am trying to load the universal_sentence_encoder model available on the Kaggle model hub and utilise its pre-trained weights to fine-tune my model for downstream tasks. I had numerous tensor compatibility errors that I have resolved to get to here. If there an easier way or something I am missing out kindly help me out.
Thank you for your help :upside_down_face:!!!.

My Code

import tensorflow_hub as hub

# ChatGPT and R1 Solution which didn't work
# Avoid unsupported operations on string tensors
# tf.config.optimizer.set_jit(False)
# tf.config.set_soft_device_placement(True)

# Loading the Model
model_path = "/kaggle/input/universal-sentence-encoder/tensorflow2/universal-sentence-encoder/2"

class UniversalSentenceEncoder(tf.keras.layers.Layer):
    def __init__(self, model_path, model_trainable=True, **kwargs):
        super().__init__(**kwargs)
        self.hub_layer = hub.KerasLayer(
            model_path, trainable=model_trainable, input_shape=[], dtype=tf.string
        )
    
    def call(self, X):
        return self.hub_layer(X)

tf.keras.backend.clear_session() # Optional

# Instantiating and Loading the Universal Sentence Encoder
hub_model = UniversalSentenceEncoder(model_path, model_trainable=False)

# Input Layer
input_layer = tf.keras.layers.Input(shape=[], dtype=tf.string)

# Hub Layers
hub_encoding = hub_model(input_layer)

# Downstream Layers
dense_1 = tf.keras.layers.Dense(128, activation="elu")(hub_encoding)
dense_2 = tf.keras.layers.Dense(32, activation="elu")(dense_1)

# Output Layer
output_layer = tf.keras.layers.Dense(1, activation="sigmoid")(dense_2)

# Constructing the Model with the pretrained hub_model
universal_sentence_model = tf.keras.models.Model(
    inputs=[input_layer], outputs=[output_layer]
)

universal_sentence_model.compile(
    loss="binary_crossentropy", metrics=["accuracy"], optimizer="nadam"
)

model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "universal_sentence_model.keras", save_best_only=True, monitor="val_accuracy"
)

history = universal_sentence_model.fit(
    train_set, validation_data=valid_set, epochs=5, callbacks=[model_checkpoint]
)

Error Description

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-37-41b6527a52c0> in <cell line: 9>()
      7 )
      8 
----> 9 history = universal_sentence_model.fit(
     10     train_set, validation_data=valid_set, epochs=5, callbacks=[model_checkpoint]
     11 )

/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     51   try:
     52     ctx.ensure_initialized()
---> 53     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     54                                         inputs, attrs, num_outputs)
     55   except core._NotOkStatusException as e:

InvalidArgumentError: Graph execution error:

Detected at node data defined at (most recent call last):
<stack traces unavailable>
Detected at node data defined at (most recent call last):
<stack traces unavailable>
Detected unsupported operations when trying to compile graph __inference_one_step_on_data_93486[] on XLA_GPU_JIT: _Arg (No registered '_Arg' OpKernel for XLA_GPU_JIT devices compatible with node {{node data}}
	 (OpKernel was found, but attributes didn't match) Requested Attributes: T=DT_STRING, _output_shapes=[[32]], _user_specified_name="data", index=0){{node data}}
The op is created at: 
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>
File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start
File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start
File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 685, in <lambda>
File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 738, in _run_callback
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 825, in inner
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 361, in process_one
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute
File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
File "<ipython-input-37-41b6527a52c0>", line 9, in <cell line: 9>
File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 320, in fit
File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 121, in one_step_on_iterator
File "/usr/local/lib/python3.10/dist-packages/tensorflow/core/function/polymorphism/function_type.py", line 356, in placeholder_arguments
File "/usr/local/lib/python3.10/dist-packages/tensorflow/core/function/trace_type/default_types.py", line 250, in placeholder_value
File "/usr/local/lib/python3.10/dist-packages/tensorflow/core/function/trace_type/default_types.py", line 251, in <listcomp>
	tf2xla conversion failed while converting __inference_one_step_on_data_93486[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.
	 [[StatefulPartitionedCall]] [Op:__inference_one_step_on_iterator_93901]