tflite_infer.py 763 B

12345678910111213141516171819202122
  1. import numpy as np
  2. import tensorflow as tf
  3. # Load the TFLite model and allocate tensors
  4. interpreter = tf.lite.Interpreter(model_path="torch_script_model/doc_clean.tflite")
  5. interpreter.allocate_tensors()
  6. # Get input and output tensors
  7. input_details = interpreter.get_input_details()
  8. output_details = interpreter.get_output_details()
  9. # Test the model on random input data
  10. input_shape = input_details[0]['shape']
  11. input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
  12. interpreter.set_tensor(input_details[0]['index'], input_data)
  13. interpreter.invoke()
  14. # get_tensor() returns a copy of the tensor data
  15. # use tensor() in order to get a pointer to the tensor
  16. output_data = interpreter.get_tensor(output_details[0]['index'])
  17. print(output_data)