from daiedge_vlab import dAIEdgeVLabAPI, OnDeviceTrainingConfig
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Setup file handed to dAIEdgeVLabAPI when logging in (Step 6).
SETUP_FILE = "setup.yaml"

# Benchmark target, runtime, and the TFLite file that is uploaded for it.
TARGET = "rpi5"
RUNTIME = "tflite"
MODEL = "model.tflite"

# Number of leading MNIST samples kept from the train / test splits.
TRAIN_NB_SAMPLES = 10_000
TEST_NB_SAMPLES = 2_000

# Directory where benchmark results (including the trained checkpoint) land.
RESULT_DIR = "./results"

############################################
# Step 1: Load, Preprocess MNIST, and Save Data
############################################

def _preprocess_split(images, labels, nb_samples=None):
    """Truncate one MNIST split and convert it to the layout the model expects.

    Args:
        images: image array as returned by ``mnist.load_data()`` (uint8,
            ``(N, 28, 28)`` per the Keras docs).
        labels: integer class labels, one per image.
        nb_samples: keep only the first ``nb_samples`` examples; ``None``
            keeps the whole split.

    Returns:
        ``(x, y)`` where ``x`` is float32 scaled by ``1/255`` with a trailing
        channel axis added, and ``y`` is int32.
    """
    # Slice *before* converting: otherwise the full split is materialised as
    # float32 only to throw most of it away.
    images = images[:nb_samples]
    labels = labels[:nb_samples]
    x = (images.astype(np.float32) / 255.0)[..., np.newaxis]
    y = labels.astype(np.int32)
    return x, y


def prepare_datasets():
    """Load MNIST, preprocess a subset of it and dump it as raw ``.bin`` files.

    Only the first ``TRAIN_NB_SAMPLES`` training and ``TEST_NB_SAMPLES`` test
    examples are kept (reducing the dataset for faster training). Images are
    written as float32 and labels as int32 with ``ndarray.tofile`` (raw bytes,
    no header) to ``mnist_{train,test}_{input,target}.bin`` in the current
    working directory.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` as NumPy arrays.
    """
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    x_train, y_train = _preprocess_split(x_train, y_train, TRAIN_NB_SAMPLES)
    x_test, y_test = _preprocess_split(x_test, y_test, TEST_NB_SAMPLES)

    # Save preprocessed data as binary files.
    x_train.tofile("mnist_train_input.bin")
    y_train.tofile("mnist_train_target.bin")
    x_test.tofile("mnist_test_input.bin")
    y_test.tofile("mnist_test_target.bin")

    print(len(x_train), len(y_train), len(x_test), len(y_test))
    print("Preprocessed train/test data saved as .bin files.")
    return (x_train, y_train), (x_test, y_test)

############################################
# Step 2: Define a Simple Keras Model.
############################################
def create_model():
    """Build the small CNN classifier used for on-device training.

    Takes ``(28, 28, 1)`` images on an input named ``"input"`` and returns
    10-way softmax probabilities from a layer named ``"output"``.

    Layer names other than the output are left to Keras' auto-naming. The raw
    checkpoint save/restore in this file keys tensors by weight name, so keep
    the layer construction order stable.
    """
    inputs = tf.keras.Input(shape=(28, 28, 1), name="input")

    kl = tf.keras.layers
    stack = [
        kl.Conv2D(32, 3, activation="relu"),
        kl.MaxPooling2D(),
        kl.Conv2D(64, 3, activation="relu"),
        kl.MaxPooling2D(),
        kl.Flatten(),
        kl.Dense(128, activation="relu"),
        kl.Dense(10, activation="softmax", name="output"),
    ]

    tensor = inputs
    for layer in stack:
        tensor = layer(tensor)
    return tf.keras.Model(inputs, tensor)


############################################
# Step 3: Define Custom Training, Inference and Save Signatures.
############################################
# Training signature: performs one training step.
@tf.function(input_signature=[
    tf.TensorSpec([None, 28, 28, 1], tf.float32, name="x"),
    tf.TensorSpec([None], tf.int32, name="y"),
])
def train_step(x, y):
    """Run one optimisation step on a mini-batch and report its mean loss.

    Exported as the ``"train"`` signature. Uses the module-level ``model``
    (created in the ``__main__`` section) and its compiled ``optimizer``.

    Args:
        x: float32 image batch, shape ``[batch, 28, 28, 1]``.
        y: int32 class labels, shape ``[batch]``.

    Returns:
        ``{"loss": <scalar>}`` - mean sparse categorical cross-entropy.
    """
    with tf.GradientTape() as tape:
        probs = model(x, training=True)
        per_example = tf.keras.losses.sparse_categorical_crossentropy(y, probs)
        mean_loss = tf.reduce_mean(per_example)
    params = model.trainable_variables
    grads = tape.gradient(mean_loss, params)
    model.optimizer.apply_gradients(zip(grads, params))
    return {"loss": mean_loss}

# Inference signature: performs a forward pass (inference).
@tf.function(input_signature=[
    tf.TensorSpec([None, 28, 28, 1], tf.float32, name="input"),
])
def inference(x):
    """Forward pass exported as the ``"serving_default"`` signature.

    Runs the module-level ``model`` in inference mode on a float32 batch of
    shape ``[batch, 28, 28, 1]`` and returns ``{"output": probabilities}``.
    """
    return {"output": model(x, training=False)}

# Save signature: receives a checkpoint path as input and returns a simple status.
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
def save(checkpoint_path):
    """Write every weight of the module-level ``model`` to a checkpoint file.

    Exported as the ``"save"`` signature. Tensors are stored with
    ``tf.raw_ops.Save`` under their variable names, which is what
    ``restore_raw_checkpoint`` looks up when reading them back.

    Args:
        checkpoint_path: scalar string tensor, the destination file.

    Returns:
        ``{"status": "saved"}``.
    """
    weights = model.weights
    names = [var.name for var in weights]
    values = [var.read_value() for var in weights]
    tf.raw_ops.Save(filename=checkpoint_path,
                    tensor_names=names,
                    data=values,
                    name='save')
    return {"status": tf.constant("saved")}

############################################
# Helper: Define a Function to Restore Checkpoint Variables (used in Step 8).
############################################
def restore_raw_checkpoint(model, ckpt_prefix):
    """Load variables saved with tf.raw_ops.Save into *model*.

    Tensors are looked up by each weight's ``name``, which must match the
    names used when the checkpoint was written (see the ``save`` signature).
    ``RestoreV2`` falls back to the V1 reader when no V2 index file exists,
    which is what lets it read files produced by the legacy ``Save`` op.

    Args:
        model: Keras model whose weights are overwritten in place.
        ckpt_prefix: path of the checkpoint to read.
    """
    weights = model.weights
    names = [w.name for w in weights]

    print(f"Restoring weights from {names}...")

    # One RestoreV2 call reads every tensor; an empty slice spec requests the
    # whole (non-partitioned) tensor.
    values = tf.raw_ops.RestoreV2(
        prefix=ckpt_prefix,
        tensor_names=names,
        shape_and_slices=["" for _ in names],
        dtypes=[w.dtype for w in weights],
    )

    # Copy restored values back into the Keras weights, in the same order.
    for w, value in zip(weights, values):
        w.assign(value)

    print(f"OK: Weights restored from {ckpt_prefix}")

def _metric_series(report, section, key):
    """Extract ``(epoch_indices, values)`` from one section of a benchmark report.

    Args:
        report: the ``"report"`` dict of the benchmark result; each
            ``report[section]["epochs"]`` entry must carry ``"epoch_index"``
            and ``key``.
        section: report section, e.g. ``"loss_train"`` or ``"loss_test"``.
        key: per-epoch field to extract, e.g. ``"loss"`` or ``"time"``.

    Returns:
        Two equal-length tuples ``(epochs, values)``. Both are empty when the
        section has no epochs, where ``zip(*[])`` unpacking would have raised.
    """
    entries = report[section]["epochs"]
    epochs = tuple(entry["epoch_index"] for entry in entries)
    values = tuple(entry[key] for entry in entries)
    return epochs, values


def _plot_training_metrics(report, out_path="training_metrics.png"):
    """Save a two-panel figure (losses, train time per epoch) to *out_path*."""
    tr_epochs, tr_loss_vals = _metric_series(report, "loss_train", "loss")
    te_epochs, te_loss_vals = _metric_series(report, "loss_test", "loss")
    _, tr_time_vals = _metric_series(report, "loss_train", "time")

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5), sharex=True)

    # Left panel - losses. Each curve uses its own epoch indices so a test
    # series reported for a different set of epochs still plots correctly.
    axes[0].plot(tr_epochs, tr_loss_vals, label="Train Loss")
    axes[0].plot(te_epochs, te_loss_vals, label="Test Loss")
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("Loss")
    axes[0].set_title("Loss over Epochs")
    axes[0].legend()

    # Right panel - time per epoch, with the mean as a reference line.
    axes[1].plot(tr_epochs, tr_time_vals, label="Train Time", color="orange")
    if tr_time_vals:  # np.mean of an empty sequence would be NaN (+ warning)
        axes[1].axhline(np.mean(tr_time_vals), color="red", linestyle="--",
                        label="Avg Time")
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("Time (s)")
    axes[1].set_title("Training Time per Epoch")
    axes[1].legend()

    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)  # release the figure; this script never shows it


if __name__ == '__main__':

    # Load and preprocess the MNIST dataset.
    (x_train, y_train), (x_test, y_test) = prepare_datasets()

    # Create the model. It must be bound to the module-level name `model`:
    # the train_step / inference / save tf.functions look it up globally.
    model = create_model()

    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])

    # Optional - pre-train the model (adjust epochs as needed).
    # model.fit(x_train, y_train, epochs=1, validation_data=(x_test, y_test))

    # Make all layers trainable.
    model.trainable = True

    # Alternatively, you can set specific layers to be trainable.
    # For example, if you want to train only the last layer:
    # for layer in model.layers:
    #     layer.trainable = False
    # model.layers[-1].trainable = True

    ############################################
    # Step 4: Save the Model as a SavedModel with All Signatures.
    ############################################
    # Three signatures are exported:
    # - "train" for on-device training,
    # - "serving_default" for inference,
    # - "save" for writing the weights to a checkpoint file.
    # NOTE(review): writing a SavedModel directory through
    # model.save(..., signatures=...) is TF-Keras 2 behaviour; Keras 3 is
    # expected to reject it - confirm the TensorFlow version in use.
    model.save("mnist_model", signatures={
        "train": train_step.get_concrete_function(),
        "serving_default": inference.get_concrete_function(),
        "save": save.get_concrete_function(),
    })
    print("SavedModel with training and inference signatures saved to 'mnist_model'.")

    ############################################
    # Step 5: Convert the SavedModel to a TFLite Model for On-Device Training.
    ############################################
    converter = tf.lite.TFLiteConverter.from_saved_model("mnist_model")
    # Retain mutable (resource) variables required for training.
    converter.experimental_enable_resource_variables = True
    # Allow TF ops without a TFLite builtin kernel. Presumably required by the
    # training step and tf.raw_ops.Save in the signatures - confirm before
    # removing SELECT_TF_OPS.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_model = converter.convert()

    # Save the TFLite model.
    with open(MODEL, "wb") as f:
        f.write(tflite_model)
    print(f"TFLite model with on-device training support saved as '{MODEL}'.")

    ############################################
    # Step 6: Upload the Model and Start the Benchmark.
    ############################################
    # Log in to the dAIEdge VLab API.
    api = dAIEdgeVLabAPI(SETUP_FILE)

    # Parse the configuration file for on-device training.
    # Make sure to adjust the path to your config file.
    config = OnDeviceTrainingConfig("./config.json")

    # Start the benchmark with the specified target, runtime, dataset and model.
    benchmark_id = api.startOdtBenchmark(TARGET, RUNTIME, MODEL, config)
    print(f"Benchmark started with ID: {benchmark_id}")

    # Wait for the benchmark to finish and save the results.
    r = api.waitBenchmarkResult(benchmark_id, save_path=RESULT_DIR, verbose=True)
    print("Benchmark finished:", r)

    ############################################
    # Step 7: Plot the metrics gathered during the on-device training.
    ############################################
    _plot_training_metrics(r["report"])

    ############################################
    # Step 8: Restore the Model and Evaluate on Test Set.
    ############################################
    # Locate the model output in the results.
    model_res = r["model_output"]
    print(f"Model output: {model_res}")
    ckpt_path = f"{RESULT_DIR}/{model_res}"
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not tf.io.gfile.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found! ({ckpt_path})")

    # Reset Keras global state (including layer-name counters) so the fresh
    # model gets the same weight names the raw checkpoint is keyed by.
    tf.keras.backend.clear_session()

    # Recreate a fresh model instance.
    model = create_model()
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])

    # Restore the weights from the checkpoint.
    restore_raw_checkpoint(model, ckpt_path)

    # Evaluate the model on the test set.
    loss, acc = model.evaluate(x_test, y_test, verbose=1)
    print("loss:", loss)
    print(f"Accuracy after on-device fine-tuning: {acc:.4f}")

    # Export the fine-tuned weights as a plain TFLite model (no train/save
    # signatures are attached in this conversion).
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    lite_trained = converter.convert()
    with open("model_trained.tflite", "wb") as f:
        f.write(lite_trained)
    print("Wrote model_trained.tflite with updated weights.")



    