/*
 * Copyright 2025 NXP
 *
 * SPDX-License-Identifier: Apache-2.0
 */

#include "openthread/platform/infra_if.h"
#include "icmpv6.h"
#include "ipv6.h"
#include "openthread_border_router.h"
#include <common/code_utils.hpp>
#include <platform-zephyr.h>
#include <route.h>
#include <zephyr/kernel.h>
#include <zephyr/net/ethernet.h>
#include <zephyr/net/icmp.h>
#include <zephyr/net/mld.h>
#include <zephyr/net/net_if.h>
#include <zephyr/net/net_ip.h>
#include <zephyr/net/net_pkt.h>
#include <zephyr/net/openthread.h>
#include <zephyr/net/socket.h>
#include <zephyr/net/socket_service.h>
#include <icmpv6.h>

#if defined(CONFIG_OPENTHREAD_NAT64_TRANSLATOR)
#include <zephyr/net/net_pkt_filter.h>
#include <openthread/nat64.h>
#endif /* CONFIG_OPENTHREAD_NAT64_TRANSLATOR */

static struct otInstance *ot_instance;
static struct net_if *ail_iface_ptr;
static uint32_t ail_iface_index;
static struct net_icmp_ctx ra_ctx;
static struct net_icmp_ctx rs_ctx;
static struct net_icmp_ctx na_ctx;
static struct net_in6_addr mcast_addr;

static void infra_if_handle_backbone_icmp6(struct otbr_msg_ctx *msg_ctx_ptr);
static void handle_ra_from_ot(const uint8_t *buffer, uint16_t buffer_length);
#if defined(CONFIG_OPENTHREAD_NAT64_TRANSLATOR)
#define MAX_SERVICES CONFIG_OPENTHREAD_ZEPHYR_BORDER_ROUTER_NAT64_SERVICES

static struct zsock_pollfd sockfd_raw[MAX_SERVICES];
static void raw_receive_handler(struct net_socket_service_event *evt);
static void remove_checksums_for_eth_offloading(uint8_t *buf, uint16_t len);
static bool infra_if_nat64_try_consume_packet(struct npf_test *test, struct net_pkt *pkt);
static int raw_infra_if_sock = -1;

NET_SOCKET_SERVICE_SYNC_DEFINE_STATIC(handle_infra_if_raw_recv, raw_receive_handler, MAX_SERVICES);

struct ot_nat64_pkt_filter_test {
	struct npf_test test;
};
/* Packet filtering rules for NAT64 translator section */
static struct ot_nat64_pkt_filter_test ot_nat64_drop_rule_check = {
	.test.fn = infra_if_nat64_try_consume_packet};
/* Drop all traffic destined to and consumed by NAT64 translator */
static NPF_RULE(ot_nat64_drop_pkt_process, NET_DROP, ot_nat64_drop_rule_check);
#endif /* CONFIG_OPENTHREAD_NAT64_TRANSLATOR */

otError otPlatInfraIfSendIcmp6Nd(uint32_t aInfraIfIndex, const otIp6Address *aDestAddress,
				 const uint8_t *aBuffer, uint16_t aBufferLength)
{
	otError error = OT_ERROR_NONE;
	struct net_pkt *pkt = NULL;
	struct net_in6_addr dst = {0};
	const struct net_in6_addr *src;

	if (aBuffer[0] == NET_ICMPV6_RA) {
		handle_ra_from_ot(aBuffer, aBufferLength);
	}

	memcpy(&dst, aDestAddress, sizeof(otIp6Address));

	src = net_if_ipv6_select_src_addr(ail_iface_ptr, &dst);
	VerifyOrExit(!net_ipv6_is_addr_unspecified(src), error = OT_ERROR_FAILED);

	pkt = net_pkt_alloc_with_buffer(ail_iface_ptr, aBufferLength, NET_AF_INET6,
					NET_IPPROTO_ICMPV6, K_MSEC(100));
	VerifyOrExit(pkt, error = OT_ERROR_FAILED);

	net_pkt_set_ipv6_hop_limit(pkt, NET_IPV6_ND_HOP_LIMIT);

	VerifyOrExit(net_ipv6_create(pkt, src, &dst) == 0, error = OT_ERROR_FAILED);
	VerifyOrExit(net_pkt_write(pkt, aBuffer, aBufferLength) == 0, error = OT_ERROR_FAILED);
	net_pkt_cursor_init(pkt);

	VerifyOrExit(net_ipv6_finalize(pkt, NET_IPPROTO_ICMPV6) == 0, error = OT_ERROR_FAILED);
	VerifyOrExit(net_send_data(pkt) == 0, error = OT_ERROR_FAILED);

exit:
	if (error == OT_ERROR_FAILED) {
		if (pkt != NULL) {
			net_pkt_unref(pkt);
			pkt = NULL;
		}
	}

	return error;
}

otError otPlatInfraIfDiscoverNat64Prefix(uint32_t aInfraIfIndex)
{
	OT_UNUSED_VARIABLE(aInfraIfIndex);

	return OT_ERROR_NOT_IMPLEMENTED;
}

bool otPlatInfraIfHasAddress(uint32_t aInfraIfIndex, const otIp6Address *aAddress)
{
	struct net_if_addr *ifaddr = NULL;
	struct net_in6_addr addr = {0};

	memcpy(addr.s6_addr, aAddress->mFields.m8, sizeof(otIp6Address));

	STRUCT_SECTION_FOREACH(net_if, tmp) {
		if (net_if_get_by_iface(tmp) != aInfraIfIndex) {
			continue;
		}
		ifaddr = net_if_ipv6_addr_lookup_by_iface(tmp, &addr);
		if (ifaddr != NULL) {
			return true;
		} else {
			return false;
		}
	}
	return false;
}

otError otPlatGetInfraIfLinkLayerAddress(otInstance *aInstance, uint32_t aIfIndex,
					 otPlatInfraIfLinkLayerAddress *aInfraIfLinkLayerAddress)
{
	OT_UNUSED_VARIABLE(aInstance);
	struct net_if *iface = net_if_get_by_index(aIfIndex);
	struct net_linkaddr *link_addr = net_if_get_link_addr(iface);

	aInfraIfLinkLayerAddress->mLength = link_addr->len;
	memcpy(aInfraIfLinkLayerAddress->mAddress, link_addr->addr, link_addr->len);

	return OT_ERROR_NONE;
}

otError infra_if_init(otInstance *instance, struct net_if *ail_iface)
{
	otError error = OT_ERROR_NONE;
	int ret;

	ot_instance = instance;
	ail_iface_ptr = ail_iface;
	ail_iface_index = (uint32_t)net_if_get_by_iface(ail_iface_ptr);

	net_ipv6_addr_create_ll_allrouters_mcast(&mcast_addr);
	ret = net_ipv6_mld_join(ail_iface, &mcast_addr);

	VerifyOrExit((ret == 0 || ret == -EALREADY), error = OT_ERROR_FAILED);

	for (uint8_t i = 0; i < MAX_SERVICES; i++) {
		sockfd_raw[i].fd = -1;
	}
exit:
	return error;
}

otError infra_if_deinit(void)
{
	otError error = OT_ERROR_NONE;

	ot_instance = NULL;
	ail_iface_index = 0;

	VerifyOrExit(net_ipv6_mld_leave(ail_iface_ptr, &mcast_addr) == 0,
		     error = OT_ERROR_FAILED);

exit:
	ail_iface_ptr = NULL;

	return error;
}

static void handle_ra_from_ot(const uint8_t *buffer, uint16_t buffer_length)
{
	struct net_if *ot_iface = net_if_get_first_by_type(&NET_L2_GET_NAME(OPENTHREAD));
	struct net_if_ipv6_prefix *prefix_added = NULL;
	struct net_route_entry *route_added = NULL;
	struct net_in6_addr rio_prefix = {0};
	struct net_if_addr *ifaddr = NULL;
	struct net_in6_addr addr_to_add_from_pio = {0};
	struct net_in6_addr nexthop = {0};
	uint8_t i = sizeof(struct net_icmp_hdr) + sizeof(struct net_icmpv6_ra_hdr);

	while (i + sizeof(struct net_icmpv6_nd_opt_hdr) <= buffer_length) {
		const struct net_icmpv6_nd_opt_hdr *opt_hdr =
			(const struct net_icmpv6_nd_opt_hdr *)&buffer[i];

		i += sizeof(struct net_icmpv6_nd_opt_hdr);
		switch (opt_hdr->type) {
		case NET_ICMPV6_ND_OPT_PREFIX_INFO:
			const struct net_icmpv6_nd_opt_prefix_info *pio =
				(const struct net_icmpv6_nd_opt_prefix_info *)&buffer[i];
			prefix_added = net_if_ipv6_prefix_add(ail_iface_ptr,
							      (struct net_in6_addr *)pio->prefix,
							      pio->prefix_len, pio->valid_lifetime);
			i += sizeof(struct net_icmpv6_nd_opt_prefix_info);
			net_ipv6_addr_generate_iid(
				ail_iface_ptr, (struct net_in6_addr *)pio->prefix,
				COND_CODE_1(CONFIG_NET_IPV6_IID_STABLE,
				     ((uint8_t *)&ail_iface_ptr->config.ip.ipv6->network_counter),
				     (NULL)), COND_CODE_1(CONFIG_NET_IPV6_IID_STABLE,
				     (sizeof(ail_iface_ptr->config.ip.ipv6->network_counter)),
				     (0U)), 0U, &addr_to_add_from_pio,
							       net_if_get_link_addr(ail_iface_ptr));
			ifaddr = net_if_ipv6_addr_lookup(&addr_to_add_from_pio, NULL);
			if (ifaddr != NULL) {
				net_if_addr_set_lf(ifaddr, true);
			}
			net_if_ipv6_addr_add(ail_iface_ptr, &addr_to_add_from_pio,
					     NET_ADDR_AUTOCONF, pio->valid_lifetime);
			break;
		case NET_ICMPV6_ND_OPT_ROUTE:
			uint8_t prefix_field_len = (opt_hdr->len - 1) * 8;
			const otIp6Address *br_omr_addr = get_ot_slaac_address(ot_instance);

			const struct net_icmpv6_nd_opt_route_info *rio =
				(const struct net_icmpv6_nd_opt_route_info *)&buffer[i];

			i += sizeof(struct net_icmpv6_nd_opt_route_info);
			memcpy(&rio_prefix.s6_addr, &buffer[i], prefix_field_len);
			if (!otIp6IsAddressUnspecified(br_omr_addr)) {
				memcpy(&nexthop.s6_addr, br_omr_addr->mFields.m8,
				       sizeof(br_omr_addr->mFields.m8));
				net_ipv6_nbr_add(ot_iface, &nexthop, net_if_get_link_addr(ot_iface),
						 false, NET_IPV6_NBR_STATE_STALE);
				route_added = net_route_add(ot_iface, &rio_prefix, rio->prefix_len,
							    &nexthop, rio->route_lifetime,
							    rio->flags.prf);
			}
			break;
		default:
			break;
		}
	}
}

static int handle_icmp6_input(struct net_icmp_ctx *ctx, struct net_pkt *pkt,
			      struct net_icmp_ip_hdr *hdr,
			      struct net_icmp_hdr *icmp_hdr, void *user_data)
{
	uint16_t length = net_pkt_get_len(pkt);
	struct otbr_msg_ctx *req = NULL;
	otError error = OT_ERROR_NONE;

	VerifyOrExit(openthread_border_router_allocate_message((void **)&req) == OT_ERROR_NONE,
		     error = OT_ERROR_FAILED);

	if (net_buf_linearize(req->buffer, sizeof(req->buffer),
			      pkt->buffer, 0, length) == length) {
		req->length = length;
		memcpy(&req->addr, hdr->ipv6->src, sizeof(req->addr));
		req->cb = infra_if_handle_backbone_icmp6;

		openthread_border_router_post_message(req);
	} else {
		openthread_border_router_deallocate_message((void *)req);
		ExitNow(error = OT_ERROR_FAILED);
	}

exit:
	if (error == OT_ERROR_NONE) {
		return 0;
	}

	return -1;
}

static void infra_if_handle_backbone_icmp6(struct otbr_msg_ctx *msg_ctx_ptr)
{
	otPlatInfraIfRecvIcmp6Nd(
		ot_instance, ail_iface_index, &msg_ctx_ptr->addr,
		(const uint8_t *)&msg_ctx_ptr->buffer[sizeof(struct net_ipv6_hdr)],
		msg_ctx_ptr->length);
}

otError infra_if_start_icmp6_listener(void)
{
	otError error = OT_ERROR_NONE;

	VerifyOrExit(net_icmp_init_ctx(&ra_ctx, NET_AF_INET6, NET_ICMPV6_RA, 0,
				       handle_icmp6_input) == 0,
		     error = OT_ERROR_FAILED);
	VerifyOrExit(net_icmp_init_ctx(&rs_ctx, NET_AF_INET6, NET_ICMPV6_RS, 0,
				       handle_icmp6_input) == 0,
		     error = OT_ERROR_FAILED);
	VerifyOrExit(net_icmp_init_ctx(&na_ctx, NET_AF_INET6, NET_ICMPV6_NA, 0,
				       handle_icmp6_input) == 0,
		     error = OT_ERROR_FAILED);

exit:
	return error;
}

void infra_if_stop_icmp6_listener(void)
{
	(void)net_icmp_cleanup_ctx(&ra_ctx);
	(void)net_icmp_cleanup_ctx(&rs_ctx);
	(void)net_icmp_cleanup_ctx(&na_ctx);
}

#if defined(CONFIG_OPENTHREAD_NAT64_TRANSLATOR)
otError infra_if_nat64_init(void)
{
	otError error = OT_ERROR_NONE;
	struct net_sockaddr_in anyaddr = {.sin_family = NET_AF_INET,
					  .sin_port = 0,
					  .sin_addr = NET_INADDR_ANY_INIT};

	raw_infra_if_sock = zsock_socket(NET_AF_INET, NET_SOCK_RAW, NET_IPPROTO_IP);
	VerifyOrExit(raw_infra_if_sock >= 0, error = OT_ERROR_FAILED);
	VerifyOrExit(zsock_bind(raw_infra_if_sock, (struct net_sockaddr *)&anyaddr,
				sizeof(struct net_sockaddr_in)) == 0,
		     error = OT_ERROR_FAILED);

	sockfd_raw[0].fd = raw_infra_if_sock;
	sockfd_raw[0].events = ZSOCK_POLLIN;

	VerifyOrExit(net_socket_service_register(&handle_infra_if_raw_recv, sockfd_raw,
						 ARRAY_SIZE(sockfd_raw), NULL) == 0,
		     error = OT_ERROR_FAILED);

	npf_insert_ipv4_recv_rule(&ot_nat64_drop_pkt_process);
	npf_append_ipv4_recv_rule(&npf_default_ok);

exit:
	return error;
}

static void raw_receive_handler(struct net_socket_service_event *evt)
{
	int len;
	struct net_pkt *ot_pkt = NULL;
	struct otbr_msg_ctx *req = NULL;
	otError error = OT_ERROR_NONE;

	VerifyOrExit(openthread_border_router_allocate_message((void **)&req) == OT_ERROR_NONE,
		     error = OT_ERROR_FAILED);
	VerifyOrExit(evt->event.revents & ZSOCK_POLLIN);

	len = zsock_recv(raw_infra_if_sock, req->buffer, sizeof(req->buffer), 0);
	VerifyOrExit(len >= 0, error = OT_ERROR_FAILED);

	ot_pkt = net_pkt_alloc_with_buffer(ail_iface_ptr, len, NET_AF_INET, 0, K_NO_WAIT);
	VerifyOrExit(ot_pkt != NULL, error = OT_ERROR_FAILED);

	VerifyOrExit(net_pkt_write(ot_pkt, req->buffer, len) == 0, error = OT_ERROR_FAILED);

	openthread_border_router_deallocate_message((void *)req);
	req = NULL;

	VerifyOrExit(notify_new_tx_frame(ot_pkt) == 0, error = OT_ERROR_FAILED);

exit:
	if (error != OT_ERROR_NONE) {
		if (ot_pkt != NULL) {
			net_pkt_unref(ot_pkt);
		}
		if (req != NULL) {
			openthread_border_router_deallocate_message((void *)req);
		}
	}
}

otError infra_if_send_raw_message(uint8_t *buf, uint16_t len)
{
	otError error = OT_ERROR_NONE;

	remove_checksums_for_eth_offloading(buf, len);

	VerifyOrExit(zsock_send(raw_infra_if_sock, buf, len, 0) > 0,
		     error = OT_ERROR_FAILED);

exit:
	return error;
}

static void remove_checksums_for_eth_offloading(uint8_t *buf, uint16_t len)
{
	struct net_ipv4_hdr *ipv4_hdr = (struct net_ipv4_hdr *)buf;
	uint8_t *pkt_cursor = NULL;
	struct ethernet_config config;

	if ((net_eth_get_hw_capabilities(ail_iface_ptr) & ETHERNET_HW_TX_CHKSUM_OFFLOAD) == 0) {
		return; /* No checksum offload capabilities*/
	}

	if (net_eth_get_hw_config(ail_iface_ptr, ETHERNET_CONFIG_TYPE_TX_CHECKSUM_SUPPORT,
				  &config) != 0) {
		return; /* No TX checksum capabilities*/
	}

	pkt_cursor = buf + (ipv4_hdr->vhl & 0x0F) * 4;

	if ((config.chksum_support & NET_IF_CHECKSUM_IPV4_HEADER) != 0) {
		ipv4_hdr->chksum = 0;
	}

	switch (ipv4_hdr->proto) {
	case NET_IPPROTO_ICMP:
		if ((config.chksum_support & NET_IF_CHECKSUM_IPV4_ICMP) != 0) {
			struct net_icmp_hdr *icmp_hdr = (struct net_icmp_hdr *)pkt_cursor;

			icmp_hdr->chksum = 0;
		}
	break;
	case NET_IPPROTO_UDP:
		if ((config.chksum_support & NET_IF_CHECKSUM_IPV4_UDP) != 0) {
			struct net_udp_hdr *udp_hdr = (struct net_udp_hdr *)pkt_cursor;

			udp_hdr->chksum = 0;
		}
	break;
	case NET_IPPROTO_TCP:
		if ((config.chksum_support & NET_IF_CHECKSUM_IPV4_TCP) != 0) {
			struct net_tcp_hdr *tcp_hdr = (struct net_tcp_hdr *)pkt_cursor;

			tcp_hdr->chksum = 0;
		}
	break;
	default:
		break;
	}
}

static bool infra_if_nat64_try_consume_packet(struct npf_test *test, struct net_pkt *pkt)
{
	ARG_UNUSED(test);

	struct net_buf *buf = NULL;
	otMessage *message = NULL;
	otMessageSettings settings;

	openthread_mutex_lock();

	if (ot_instance == NULL ||
	    otNat64GetTranslatorState(ot_instance) != OT_NAT64_STATE_ACTIVE) {
		ExitNow();
	}

	settings.mPriority = OT_MESSAGE_PRIORITY_NORMAL;
	settings.mLinkSecurityEnabled = true;

	message = otIp4NewMessage(ot_instance, &settings);
	VerifyOrExit(message != NULL);

	for (buf = pkt->buffer; buf; buf = buf->frags) {
		if (otMessageAppend(message, buf->data, buf->len) != OT_ERROR_NONE) {
			otMessageFree(message);
			ExitNow();
		}
	}

	if (otNat64Send(ot_instance, message) == OT_ERROR_NONE) {
		net_pkt_unref(pkt);
		openthread_mutex_unlock();
		return true;
	}

exit:
	openthread_mutex_unlock();
	return false;
}

#endif /* CONFIG_OPENTHREAD_NAT64_TRANSLATOR */
