import numpy as np
import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split
from dAIEdgeVLabAPI import dAIEdgeVLabAPI

#----------------------------------------------------
# Load the MNIST dataset and preprocess it 

DATASET = "dataset_mnist.bin"

# Load MNIST dataset (downloaded on first use, cached by Keras afterwards)
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize pixel values to [0,1] and cast BOTH splits to float32.
# FIX: the original cast only x_test, leaving x_train as float64 after the
# division — float32 keeps the splits consistent and halves memory, and it
# matches the binary layout written to DATASET below.
x_train = (x_train / 255.0).astype(np.float32)
x_test = (x_test / 255.0).astype(np.float32)

# Expand dimensions for CNN input (from (28,28) to (28,28,1))
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# Split training set into training (90%) and validation (10%) sets
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=0.1, random_state=42
)

# Print dataset sizes
print(f"Training set: {x_train.shape[0]} samples")
print(f"Validation set: {x_val.shape[0]} samples")
print(f"Test set: {x_test.shape[0]} samples")

# Save the test dataset in binary format (flat row-major float32, the
# layout the benchmark target reads back)
x_test.tofile(DATASET)

#----------------------------------------------------
# Train a CNN model and convert it to TFLite

MODEL = "mnist_model.tflite"

# CNN for 28x28 grayscale digits: two conv/pool stages feeding a dense
# head that ends in a 10-way softmax (digits 0-9).
cnn_layers = [
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu',
                        input_shape=(28, 28, 1)),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax'),
]
model = keras.Sequential(cnn_layers)

# Integer class labels -> sparse categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Fit for five epochs, monitoring the held-out validation split each epoch.
model.fit(x_train, y_train, epochs=5, validation_data=(x_val, y_val))

# Convert the trained Keras model to a TensorFlow Lite flatbuffer.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Persist the flatbuffer so it can be uploaded to the benchmark service.
with open(MODEL, 'wb') as f:
    f.write(tflite_model)

#----------------------------------------------------
# Run the model on the dAIEdge VLab
    
# Hardware in the loop
# Hardware in the loop: target board and inference runtime to benchmark on.
TARGET = "rpi5"
RUNTIME = "tflite"

# Connection/credentials are read from the local setup file.
api = dAIEdgeVLabAPI("setup.yaml")

api.uploadDataset(DATASET)

# FIX: renamed `id` -> `benchmark_id`; `id` shadows the Python builtin.
benchmark_id = api.startBenchmark(
    target=TARGET,
    runtime=RUNTIME,
    model_path=MODEL,
    dataset=DATASET,
)

print(f"Benchmark started for {TARGET} - {RUNTIME}")

# Block until the remote run completes and returns its report/raw output.
result = api.waitBenchmarkResult(benchmark_id, verbose=True)

#----------------------------------------------------
# Evaluate the results

nb_samples = len(x_test)

# Reinterpret the raw benchmark output as float32 and arrange it as one
# 10-value output row per test sample.
output_matrix = np.frombuffer(result["raw_output"], dtype=np.float32)
output_matrix = output_matrix.reshape(nb_samples, 10)

# Predicted digit = index of the largest value in each row.
predictions = output_matrix.argmax(axis=1)

# Fraction of predictions that match the ground-truth labels.
accuracy = np.mean(predictions == y_test)

print("Number of inference:", result["report"]["nb_inference"])
print("Mean inference time us:", result["report"]["inference_latency"]["mean"])
print("Accuracy:", accuracy)