#include <lib/seg/mask_data.h>
#include <lib/seg/util.h>

#include <stdio.h>

// Allocate a MaskData record for `label` with its area and perimeter counters
//  set to zero.
//  Returns NULL if the allocation fails; otherwise the caller owns the record
//  (it is either handed to a tree or released with free()).
MaskData* create_mask_data(MaskData_t label)
{
  MaskData *data = malloc(sizeof *data);
  if (data == NULL) {
    return NULL;
  }
  data->label = label;
  data->area = 0;
  data->perimeter = 0;
  return data;
}

// Strict "less than" ordering on MaskData labels: TRUE when `left` sorts
//  before `right`, FALSE when the labels are equal or `left` sorts after.
bool_t compare_labels(MaskData* left, MaskData* right)
{
  return right->label > left->label;
}

// Create AVL Mask node
AVLNode* create_avl_mask_node(MaskData* data)
{
  AVLNode* node = (AVLNode*)malloc(sizeof(AVLNode));
  if (node == NULL) {
    return NULL;
  }
  node->data = data;
  node->compare = (bool_t (*)(void*,void*))&compare_labels;
  node->left = NULL;
  node->right = NULL;
  node->height = 1; // Leaf initially
  return node;
}

// Insert MaskData into the AVL subtree rooted at `node`, rebalancing on the way
//  back up.
//  Returns {new subtree root, TRUE} when `data` was inserted; the tree then
//  references `data`.
//  Returns {node, FALSE} when a node with an equal label already exists or a
//  new node could not be allocated. In that case the subtree is unchanged and
//  `node` is still its root, so callers can keep using it. The caller keeps
//  ownership of `data`.
Result insert_mask(AVLNode* node, MaskData* data)
{
  Result result;

  // 1. Standard BST insertion
  if (node == NULL) {
    AVLNode* leaf = create_avl_mask_node(data);
    if (leaf == NULL) {
      fprintf(stderr, "Failed to insert!\n");
      return (Result) {NULL, FALSE};
    }
    return (Result) {leaf, TRUE};
  }

  if (node->compare(data, node->data)) {
    result = insert_mask(node->left, data);
    if (!result.success) {
      // Nothing below was modified; report failure with this subtree's root so
      // the caller never adopts a descendant as the root.
      return (Result) {node, FALSE};
    }
    node->left = (AVLNode*)result.data;
  } else if (node->compare(node->data, data)) {
    result = insert_mask(node->right, data);
    if (!result.success) {
      return (Result) {node, FALSE};
    }
    node->right = (AVLNode*)result.data;
  } else {
    // Equal label already present: nothing to insert.
    return (Result) {node, FALSE};
  }

  // 2. Update height of the ancestor node
  node->height = 1 + max_height(get_height(node->left), get_height(node->right));

  // 3. Check whether this node became unbalanced
  ssize_t balance = get_balance_factor(node);

  // 4. If the node becomes unbalanced, restore balance with rotations

  // LeftLeft: single right rotation
  if ((balance > 1) && node->compare(data, node->left->data)) {
    return (Result) {right_rotate(node), TRUE};
  }
  // RightRight: single left rotation
  if ((balance < -1) && node->compare(node->right->data, data)) {
    return (Result) {left_rotate(node), TRUE};
  }
  // LeftRight: rotate the left child left first, then this node right
  if ((balance > 1) && node->compare(node->left->data, data)) {
    node->left = left_rotate(node->left);
    return (Result) {right_rotate(node), TRUE};
  }
  // RightLeft: rotate the right child right first, then this node left
  if ((balance < -1) && node->compare(data, node->right->data)) {
    node->right = right_rotate(node->right);
    return (Result) {left_rotate(node), TRUE};
  }
  return (Result) {node, TRUE};
}

// Allocate a label's MaskData (area = perimeter = 0) and insert it into the tree
//  rooted at `node`.
//  If the label already exists, or an allocation fails, the new record is
//  released and the original root is returned unchanged.
//  Returns the (possibly new) root of the tree.
AVLNode* insert_mask_alloc(AVLNode* node, MaskData_t label)
{
  MaskData* data = create_mask_data(label);
  if (data == NULL) {
    return node;
  }
  Result result = insert_mask(node, data);
  if (!result.success) {
    // The tree did not take ownership of `data` and was not modified, so
    // `node` is still the correct root (result.data may be a subtree).
    free(data);
    return node;
  }
  return (AVLNode*)result.data;
}

// Print "label: (area, perimeter) " for every node via an in-order walk
//  (left subtree, node, right subtree), i.e. in ascending label order for trees
//  built with insert_mask.
void print_label(AVLNode* root)
{
  if (root == NULL) {
    return;
  }
  print_label(root->left);
  const MaskData* entry = (const MaskData*)root->data;
  printf("%d: (%zu, %zu) ", entry->label, entry->area, entry->perimeter);
  print_label(root->right);
}

// Increment the area counter of the MaskData whose label equals `label`.
//  Searches the label-ordered tree iteratively.
//  Returns TRUE if the label was found, FALSE if it is not in the tree.
bool_t increase_label_area(AVLNode* root, MaskData_t label)
{
  AVLNode* cursor = root;
  while (cursor != NULL) {
    MaskData* entry = (MaskData*)cursor->data;
    if (entry->label == label) {
      entry->area++;
      return TRUE;
    }
    if (entry->label > label) {
      cursor = cursor->left;
    } else if (entry->label < label) {
      cursor = cursor->right;
    } else {
      // Neither equal nor ordered (e.g. NaN for a floating label type):
      // reported as TRUE without updating anything.
      return TRUE;
    }
  }
  return FALSE;
}

// Increment the perimeter counter of the MaskData whose label equals `label`.
//  Searches the label-ordered tree iteratively.
//  Returns TRUE if the label was found, FALSE if it is not in the tree.
bool_t increase_label_perimeter(AVLNode* root, MaskData_t label)
{
  for (AVLNode* cursor = root; cursor != NULL; ) {
    MaskData* entry = (MaskData*)cursor->data;
    if (entry->label == label) {
      entry->perimeter++;
      return TRUE;
    } else if (entry->label > label) {
      cursor = cursor->left;
    } else if (entry->label < label) {
      cursor = cursor->right;
    } else {
      // Unordered comparison (only possible for e.g. NaN labels): report TRUE.
      return TRUE;
    }
  }
  return FALSE;
}

// Increment the area of `label`, first inserting a zeroed MaskData record if
//  the label is not yet in the tree.
//  Returns the (possibly new) root of the tree.
AVLNode* increase_label_area_alloc(AVLNode* root, MaskData_t label)
{
  if (increase_label_area(root, label) != FALSE) {
    return root;
  }
  // Label missing: add it, then count this hit against the new record.
  AVLNode* grown = insert_mask_alloc(root, label);
  increase_label_area(grown, label);
  return grown;
}

// Increment the perimeter of `label`, first inserting a zeroed MaskData record
//  if the label is not yet in the tree.
//  Returns the (possibly new) root of the tree.
AVLNode* increase_label_perimeter_alloc(AVLNode* root, MaskData_t label)
{
  if (increase_label_perimeter(root, label) != FALSE) {
    return root;
  }
  // Label missing: add it, then count this hit against the new record.
  AVLNode* grown = insert_mask_alloc(root, label);
  increase_label_perimeter(grown, label);
  return grown;
}

// Strict "less than" on two pointed-to MaskData_t values: TRUE when *s1 sorts
//  before *s2.
bool_t compare_image_mask_data_t(MaskData_t* s1, MaskData_t* s2)
{
  return *s2 > *s1;
}

// Print every stored MaskData_t value (node data is a MaskData_t*), each
//  followed by a space, in in-order sequence: left subtree, node, right subtree.
void print_in_order_image_mask_data_t(AVLNode* root)
{
  if (root == NULL) {
    return;
  }
  print_in_order_image_mask_data_t(root->left);
  const MaskData_t* value = (const MaskData_t*)root->data;
  printf("%d ", *value);
  print_in_order_image_mask_data_t(root->right);
}

// Report whether `value` is stored in a tree whose node data are MaskData_t*
//  ordered ascending. Walks down from the root iteratively.
//  Returns TRUE if found, FALSE otherwise (including for an empty tree).
bool_t in_image_mask_data_t_tree(AVLNode* root, MaskData_t value)
{
  AVLNode* cursor = root;
  while (cursor != NULL) {
    MaskData_t key = *((MaskData_t*)cursor->data);
    if (key == value) {
      return TRUE;
    }
    cursor = (value < key) ? cursor->left : cursor->right;
  }
  return FALSE;
}

// Filter out small masks
//  Assumption: Contiguous labeling
AVLNode* get_small_labels(AVLNode* removal_tree, AVLNode* label_tree, size_t min_area, size_t min_perimeter)
{
  AVLNode* return_tree = removal_tree;
  if (label_tree != NULL) {
    return_tree = get_small_labels(return_tree, label_tree->left, min_area, min_perimeter);
    MaskData* node_data = (MaskData*)label_tree->data;
    if ((node_data->area < min_area) || (node_data->perimeter < min_perimeter)) {
      // Insert
      Result result = avl_insert(return_tree, &node_data->label, (bool_t (*)(void*,void*))compare_image_mask_data_t);
      if (result.success) {
	return_tree = result.data;
      }
    }
    return_tree = get_small_labels(return_tree, label_tree->right, min_area, min_perimeter);
  }
  return return_tree;
}

// Build a label-keyed AVL tree of MaskData for every non-zero pixel label in
//  `mask`. Area counts the label's pixels; perimeter counts the label's pixels
//  for which is_on_mask_boundary() reports true.
//  Returns the tree root (NULL if the mask has no non-zero pixels).
AVLNode *get_mask_data(Mask *mask)
{
  const uint32_t cols = mask->width;
  const uint32_t rows = mask->height;
  AVLNode* stats = NULL;
  for (size_t row = 0; row < rows; row++) {
    for (size_t col = 0; col < cols; col++) {
      if (mask->image[row][col] == 0) {
        continue; // background pixel
      }
      stats = increase_label_area_alloc(stats, mask->image[row][col]);
      // The label's record exists now, so the non-allocating variant suffices.
      if (is_on_mask_boundary(mask, col, row)) {
        increase_label_perimeter(stats, mask->image[row][col]);
      }
    }
  }
  return stats;
}

// Zero out, in place, every pixel of `mask` whose label has area below
//  `min_area` or perimeter below `min_perimeter`, as measured by get_mask_data.
void filter_small_masks(Mask *mask, size_t min_area, size_t min_perimeter)
{
  const uint32_t cols = mask->width;
  const uint32_t rows = mask->height;
  AVLNode* label_stats = get_mask_data(mask);
  AVLNode* doomed = get_small_labels(NULL, label_stats, min_area, min_perimeter);

  for (size_t row = 0; row < rows; row++) {
    for (size_t col = 0; col < cols; col++) {
      if (in_image_mask_data_t_tree(doomed, mask->image[row][col])) {
        mask->image[row][col] = 0;
      }
    }
  }

  // NOTE(review): `doomed` stores pointers to the label fields inside the
  // MaskData owned by `label_stats`. Confirm that free_avl_tree does not free
  // those borrowed pointers, and that the MaskData from create_mask_data are
  // actually released by one of these two calls.
  free_avl_tree(doomed);
  free_avl_tree_nodes(label_stats);
}