#!/usr/bin/env python3

import argparse
import difflib
import os
import re
import shutil
import subprocess
import sys
import tempfile


# Version banner printed by -v/--version.
VERSION = "clang-format-radare2 1.0"
# Absolute path of the directory that contains this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Parent directory of SCRIPT_DIR; relative whitelist entries are resolved against it.
PROJECT_ROOT = os.path.abspath(os.path.join(SCRIPT_DIR, os.pardir))
# Text file next to this script listing the files handled by --auto
# (one path per line; blank lines and lines starting with '#' are skipped).
AUTO_WHITELIST_PATH = os.path.join(SCRIPT_DIR, "auto-format-files.txt")

# Embedded clang-format style definition (YAML). main() writes it to a
# temporary file that is handed to clang-format through -style=file:<path>,
# and --print-config dumps it verbatim. The text is runtime data: do not
# reformat it or add comments inside the string.
CLANG_FORMAT_CONFIG = """BasedOnStyle: LLVM
Language: Cpp
PointerAlignment: Right
AlwaysBreakAfterDefinitionReturnType: None
BinPackParameters: false
BinPackArguments: false
MaxEmptyLinesToKeep: 1
SpaceInEmptyParentheses: false
SpacesInContainerLiterals: true
SpaceBeforeParens: Custom
SpaceBeforeParensOptions:
  AfterIfMacros: true
  AfterFunctionDefinitionName: false
  AfterFunctionDeclarationName: false
  AfterForeachMacros: true
  AfterControlStatements: true
  BeforeNonEmptyParentheses: false
SpacesInParentheses: false
InsertBraces: true
ContinuationIndentWidth: 8
IndentCaseLabels: false
IndentFunctionDeclarationAfterType: false
IndentWidth: 8
UseTab: ForContinuationAndIndentation
ColumnLimit: 0
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
SpaceBeforeSquareBrackets: false
SpaceInEmptyBlock: false
AllowShortIfStatementsOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false
AlignAfterOpenBracket: DontAlign
AlignEscapedNewlines: DontAlign
AlignConsecutiveMacros: true
AlignTrailingComments: false
AlignOperands: false
Cpp11BracedListStyle: false
ForEachMacros: ['r_list_foreach', 'ls_foreach', 'fcn_tree_foreach_intersect', 'r_skiplist_foreach', 'graph_foreach_anode', 'r_list_foreach_safe', 'R_VEC_FOREACH', 'R_VEC_FOREACH_I', 'R_VEC_FOREACH_PREV', 'r_rbtree_foreach', 'r_interval_tree_foreach']
SortIncludes: false
"""


def parse_args():
	"""Parse sys.argv and return the resulting argparse namespace.

	The namespace carries: in_place, no_update, version, auto, files (list),
	clang_format (defaults to $CLANG_FORMAT or "clang-format") and print_config.
	"""
	# Option table, in the order the options are registered (and therefore shown in --help).
	option_table = (
		(
			("-i", "--in-place"),
			{
				"dest": "in_place",
				"action": "store_true",
				"help": "Do nothing, because that's the default behaviour, for muscle memory compatibility reasons with clang-format",
			},
		),
		(
			("-n", "--no-update"),
			{
				"dest": "no_update",
				"action": "store_true",
				"help": "Do not modify files; report differences and exit with error if formatting is needed.",
			},
		),
		(
			("-v", "--version"),
			{
				"action": "store_true",
				"help": "Show the clang-format-radare2 version and exit.",
			},
		),
		(
			("-a", "--auto"),
			{
				"action": "store_true",
				"help": "Auto mode: ignore provided paths and use the indentation whitelist.",
			},
		),
		(
			("files",),
			{
				"nargs": "*",
				"help": "Files or directories to format in place.",
			},
		),
		(
			("--clang-format",),
			{
				"dest": "clang_format",
				# Environment is consulted at call time, not at import time.
				"default": os.environ.get("CLANG_FORMAT", "clang-format"),
				"help": "clang-format executable to use (default: env CLANG_FORMAT or clang-format).",
			},
		),
		(
			("--print-config",),
			{
				"action": "store_true",
				"help": "Print the embedded .clang-format configuration to stdout and exit.",
			},
		),
	)
	cli = argparse.ArgumentParser(
		description="Format files using clang-format followed by radare2 indentation rules."
	)
	for names, options in option_table:
		cli.add_argument(*names, **options)
	return cli.parse_args()


def is_source_file(path):
	"""Return True when *path* ends in a .c or .h extension (case-insensitive)."""
	_, suffix = os.path.splitext(path)
	return suffix.lower() in {".c", ".h"}


def expand_targets(paths):
	"""Expand *paths* into an ordered, de-duplicated list of files.

	Directories are walked recursively and contribute only entries accepted
	by is_source_file(); any other non-empty path is passed through as-is
	(its existence is not checked here).
	"""
	collected = []
	visited = set()

	def remember(entry):
		collected.append(entry)
		visited.add(entry)

	for target in paths:
		if not target or target in visited:
			continue
		if not os.path.isdir(target):
			remember(target)
			continue
		for dirpath, _subdirs, names in os.walk(target):
			candidates = (os.path.join(dirpath, name) for name in names if is_source_file(name))
			for candidate in candidates:
				if candidate not in visited:
					remember(candidate)
	return collected


def load_auto_files():
	"""Return the existing files listed in the auto-format whitelist.

	Blank lines and lines starting with '#' are ignored. Relative entries are
	resolved against PROJECT_ROOT. Entries that are not regular files, and
	duplicates, are dropped. An unreadable whitelist yields an empty list.
	"""
	try:
		with open(AUTO_WHITELIST_PATH, "r", encoding="utf-8") as handle:
			entries = [row.strip() for row in handle]
	except OSError:
		return []
	unique = []
	known = set()
	for entry in entries:
		if not entry or entry.startswith("#"):
			continue
		candidate = entry if os.path.isabs(entry) else os.path.join(PROJECT_ROOT, entry)
		if candidate not in known and os.path.isfile(candidate):
			unique.append(candidate)
			known.add(candidate)
	return unique


def show_diff(path, original, formatted):
	"""Write a unified diff between *original* and *formatted* to stdout.

	The file is labelled with its path relative to the current directory when
	that can be computed, otherwise with *path* unchanged.
	"""
	try:
		label = os.path.relpath(path)
	except ValueError:
		# relpath() cannot relate the two locations; fall back to the raw path.
		label = path
	before = original.splitlines(keepends=True)
	after = formatted.splitlines(keepends=True)
	diff = difflib.unified_diff(
		before,
		after,
		fromfile=label,
		tofile=f"{label} (formatted)",
	)
	sys.stdout.writelines(diff)


def convert_leading_spaces(line):
	"""Rewrite the leading blanks of *line* as tabs only.

	Each leading space counts as 1 column and each leading tab as 8; the
	prefix is replaced by columns // 8 tabs, so any remainder is dropped.
	Lines without leading blanks are returned unchanged.
	"""
	body = line.lstrip(" \t")
	prefix = line[:len(line) - len(body)]
	if not prefix:
		return line
	columns = prefix.count(" ") + 8 * prefix.count("\t")
	return "\t" * (columns // 8) + body


def rewrite_outside_literals(line, handler):
	"""Rebuild *line*, letting *handler* rewrite text outside quoted literals.

	The line is scanned character by character. Characters inside '...' or
	"..." (including backslash-escaped ones) are copied verbatim. For every
	other character that is not itself a quote, handler(line, out, pos) is
	called with the list of pieces emitted so far; it may edit that list and
	return the index to resume at, or return None to have the character
	copied unchanged.
	"""
	out = []
	quote = None  # quote character of the literal being scanned, if any
	escaped = False  # previous character was a backslash inside a literal
	pos = 0
	end = len(line)
	while pos < end:
		ch = line[pos]
		if escaped:
			escaped = False
		elif quote is not None:
			if ch == "\\":
				escaped = True
			elif ch == quote:
				quote = None
		elif ch in ("'", '"'):
			quote = ch
		else:
			resume = handler(line, out, pos)
			if resume is not None:
				pos = resume
				continue
		out.append(ch)
		pos += 1
	return "".join(out)


# Captures the (possibly empty) run of whitespace at the start of a string.
LEADING_WS_RE = re.compile(r"^(\s*)")


def leading_ws(line):
	"""Return the leading whitespace of *line* ("" when there is none)."""
	# The pattern also accepts an empty prefix, so match() always succeeds.
	return LEADING_WS_RE.match(line).group(1)


def fix_paren_spacing(line):
	"""Normalise the blanks in front of every '(' outside literals.

	Spaces before '(' are removed; one space is then re-inserted unless the
	line starts with an ASCII letter or digit, starts with '#', or the '(' is
	preceded by nothing or by one of: ( TAB NEWLINE * _ & [
	"""
	starts_alnum = re.match(r"[A-Za-z0-9]", line) is not None
	is_directive = line.startswith("#")
	glued_after = ("", "(", "\t", "\n", "*", "_", "&", "[")

	def on_paren(text, out, pos):
		if text[pos] != "(":
			return None
		while out and out[-1] == " ":
			out.pop()
		last = out[-1] if out else ""
		if not (starts_alnum or is_directive or last in glued_after):
			out.append(" ")
		out.append("(")
		return pos + 1

	return rewrite_outside_literals(line, on_paren)


def split_line_ending(line):
	"""Split *line* into (content, ending).

	ending is "\\r\\n", "\\n", or "" when the line has neither terminator.
	"""
	for terminator in ("\r\n", "\n"):
		if line.endswith(terminator):
			return line[:-len(terminator)], terminator
	return line, ""


# "case ...:" or "default:" followed by an opening brace on the same line.
# Group 1 is the label (with its indentation), group 2 an optional trailing // comment.
CASE_OPEN_RE = re.compile(r"^(\s*(?:case\b.*:|default:))\s*\{(\s*//.*)?$")
# "} break;" on one line. Group 1 is the indentation, group 2 whatever follows "break;".
CASE_CLOSE_BREAK_RE = re.compile(r"^(\s*)\}\s*break;(.*)$")
# Any line whose first non-blank character is "}". Group 1 is the indentation, group 2 the rest.
CASE_CLOSE_ONLY_RE = re.compile(r"^(\s*)\}(.*)$")
# A bare "identifier:" line, optionally followed by a // comment.
# Group 1 is the indentation, group 2 the identifier, group 3 the trailing blanks/comment.
LABEL_RE = re.compile(r"^(\s*)([A-Za-z_][A-Za-z0-9_]*):(\s*(?://.*)?)?$")


def fix_case_blocks(lines):
	"""Re-shape braced switch-case blocks.

	For a "case ...: {" / "default: {" line the brace is moved to its own line,
	one tab deeper than the label. While such a block is open, lines that
	already start with that deeper indentation are pushed one more tab in, and
	the closing "}" (optionally followed by "break;") is re-emitted at the
	brace's indentation, with "break;" on a line of its own.

	Takes and returns a list of lines that keep their line endings. Rewritten
	lines that had no ending get "\\n".
	"""
	fixed = []
	# Indentation of each currently open "case ...: {" label, innermost last.
	case_indent_stack = []
	for line in lines:
		content, ending = split_line_ending(line)
		line_sep = ending or "\n"
		match_open = CASE_OPEN_RE.match(content)
		if match_open:
			# "case X: {" -> "case X:" followed by "<indent>\t{".
			case_line = match_open.group(1).rstrip()
			comment = match_open.group(2)
			if comment:
				# Keep a trailing // comment on the label line.
				case_line = f"{case_line} {comment.strip()}"
			indent = leading_ws(content)
			fixed.append(case_line + line_sep)
			fixed.append(indent + "\t{" + line_sep)
			case_indent_stack.append(indent)
			continue
		match_close = CASE_CLOSE_BREAK_RE.match(content)
		if match_close:
			# "} break;" -> "}" and "break;" on separate lines.
			comment = match_close.group(2).strip()
			if case_indent_stack:
				# Closes the innermost tracked case block, whatever its own indentation is.
				case_indent = case_indent_stack.pop()
				inner_indent = case_indent + "\t"
			else:
				# Not inside a tracked case block: keep the line's own indentation.
				inner_indent = match_close.group(1)
			fixed.append(inner_indent + "}" + line_sep)
			if comment:
				fixed.append(inner_indent + "break; " + comment + line_sep)
			else:
				fixed.append(inner_indent + "break;" + line_sep)
			continue
		match_close_only = CASE_CLOSE_ONLY_RE.match(content)
		if match_close_only and case_indent_stack:
			indent = match_close_only.group(1)
			case_indent = case_indent_stack[-1]
			inner_indent = case_indent + "\t"
			# Only a "}" at exactly the label's indentation closes the case block;
			# any other "}" falls through and is treated as block body below.
			if indent == case_indent:
				case_indent_stack.pop()
				remainder = match_close_only.group(2).strip()
				if remainder:
					fixed.append(inner_indent + "}" + " " + remainder + line_sep)
				else:
					fixed.append(inner_indent + "}" + line_sep)
				continue
		if case_indent_stack:
			case_indent = case_indent_stack[-1]
			inner_indent = case_indent + "\t"
			body_indent = inner_indent + "\t"
			stripped = content.lstrip()
			# Lines starting with "{" or with an operator continuation are left alone.
			if content and not stripped.startswith(("{", "||", "&&", "?", ":")):
				if content.startswith(inner_indent):
					# Push the body one tab deeper than the relocated brace.
					adjusted = body_indent + content[len(inner_indent):]
					fixed.append(adjusted + line_sep)
					continue
		fixed.append(line)
	return fixed


def fix_labels(lines):
	"""Move goto labels to column 0.

	A line matching LABEL_RE is re-emitted without its indentation, keeping
	any trailing // comment. Lines whose identifier is "case" or "default" are
	switch labels and are left untouched, as is every non-matching line.
	"""
	out = []
	for raw in lines:
		text, eol = split_line_ending(raw)
		found = LABEL_RE.match(text)
		if found is None or found.group(2) in ("case", "default"):
			out.append(raw)
			continue
		trailer = found.group(3) or ""
		newline = eol or "\n"
		out.append(f"{found.group(2)}:{trailer}{newline}")
	return out


def fix_ternary_spacing(line):
	"""Apply the "a? b: c" spacing to ternary operators outside literals.

	Spaces before '?' and before its matching ':' are removed and blanks after
	them are collapsed to one space; no space is emitted at end of line or in
	front of ) , ; blanks/newline (and, after '?', in front of ':'). A ':' is
	only treated as ternary when an unmatched '?' was seen earlier on the line.
	"""
	pending = 0  # number of '?' still waiting for their ':'

	def on_operator(text, out, pos):
		nonlocal pending
		symbol = text[pos]
		is_question = symbol == "?"
		if is_question:
			pending += 1
		elif symbol == ":" and pending > 0:
			pending -= 1
		else:
			return None
		while out and out[-1] == " ":
			out.pop()
		out.append(symbol)
		pos += 1
		size = len(text)
		while pos < size and text[pos] in " \t":
			pos += 1
		no_space_before = " \t\n),;" + (":" if is_question else "")
		if pos < size and text[pos] not in no_space_before:
			out.append(" ")
		return pos

	return rewrite_outside_literals(line, on_operator)


def fix_hash_spacing(line):
	"""Surround every '##' outside literals with exactly one space on each side."""

	def on_hash(text, out, pos):
		if text[pos:pos + 2] != "##":
			return None
		while out and out[-1] == " ":
			out.pop()
		out.extend((" ", "##", " "))
		pos += 2
		# Swallow the spaces that followed the operator in the input.
		while text[pos:pos + 1] == " ":
			pos += 1
		return pos

	return rewrite_outside_literals(line, on_hash)


def fix_logical_spacing(line):
	"""Make sure '&&' and '||' outside literals are followed by whitespace."""

	def on_operator(text, out, pos):
		operator = text[pos:pos + 2]
		if operator != "&&" and operator != "||":
			return None
		out.append(operator)
		after = pos + 2
		following = text[after:after + 1]
		# Nothing is added at end of line or when a blank/newline already follows.
		if following and following not in " \t\n":
			out.append(" ")
		return after

	return rewrite_outside_literals(line, on_operator)


def fix_ampersand_paren_spacing(line):
	"""Space out a binary '&' whose right operand starts with '('.

	A single '&' (not '&&' or '&=') that is followed, after optional blanks, by
	'(' and preceded, ignoring blanks, by an alphanumeric character or one of
	_ ) ] " ' is rewritten as " & " directly in front of the parenthesis.
	Every other '&' is left untouched.
	"""
	blanks = (" ", "\t")

	def on_ampersand(text, out, pos):
		if text[pos] != "&":
			return None
		if text[pos + 1:pos + 2] in ("&", "="):
			return None
		paren = pos + 1
		while paren < len(text) and text[paren] in blanks:
			paren += 1
		if text[paren:paren + 1] != "(":
			return None
		# Nearest non-blank piece already emitted; decides whether '&' is binary.
		left = next((piece for piece in reversed(out) if piece not in blanks), "")
		if not (left.isalnum() or left in ("_", ")", "]", '"', "'")):
			return None
		while out and out[-1] == " ":
			out.pop()
		if out and out[-1] not in (" ", "\t", "\n"):
			out.append(" ")
		out.extend(("&", " "))
		return paren

	return rewrite_outside_literals(line, on_ampersand)


def format_macro_body(body):
	"""Run the per-line spacing fixers over the replacement text of a #define."""
	if not body:
		return body
	passes = (
		fix_paren_spacing,
		fix_ampersand_paren_spacing,
		fix_ternary_spacing,
		fix_hash_spacing,
		fix_logical_spacing,
	)
	for apply_pass in passes:
		body = apply_pass(body)
	return body


def normalize_define_line(line):
	"""Normalise the spacing of a "#define" line (given without line ending).

	The result is "#define NAME", "#define NAME(params)" or either of those
	followed by one space and the body processed by format_macro_body(). The
	line is returned unchanged when no macro name follows "#define" or when
	the parameter list is not closed.
	"""
	rest = line[len("#define"):].lstrip()
	if not rest:
		return "#define"
	# Length of the macro name: leading run of alphanumerics and underscores.
	name_end = 0
	for ch in rest:
		if not (ch.isalnum() or ch == "_"):
			break
		name_end += 1
	if name_end == 0:
		return line
	header = "#define " + rest[:name_end]
	if rest.startswith("(", name_end):
		# Function-like macro: locate the parenthesis closing the parameter list.
		depth = 0
		close = None
		for pos in range(name_end, len(rest)):
			ch = rest[pos]
			if ch == "(":
				depth += 1
			elif ch == ")":
				depth -= 1
				if depth == 0:
					close = pos + 1
					break
		if close is None:
			return line
		header += rest[name_end:close]
		body = rest[close:].lstrip()
	else:
		body = rest[name_end:].lstrip()
	if not body:
		return header
	return f"{header} {format_macro_body(body)}"


def fix_preprocessor_line(line):
	"""Normalise a preprocessor line; other lines are returned unchanged.

	The directive is moved to column 0, tabs become spaces, "#define" lines go
	through normalize_define_line() and '##' gets single-space padding. The
	original line ending is preserved.
	"""
	text, eol = split_line_ending(line)
	directive = text.lstrip()
	if directive[:1] != "#":
		return line
	directive = directive.replace("\t", " ")
	if directive.startswith("#define"):
		directive = normalize_define_line(directive)
	return fix_hash_spacing(directive) + eol


def fix_multiline_comments(lines):
	"""Align the continuation lines of /* ... */ block comments.

	A line that begins inside an already open block comment and whose first
	non-blank character is '*' gets exactly one space in front of that '*'
	(added after any existing indentation when it does not already end with a
	space). All other lines are returned unchanged.

	Comment state is tracked across lines, ignoring "/*" and "*/" that appear
	inside string/character literals or after "//". The line on which a
	comment opens is never re-aligned, so code such as
	"*out = 0; /* reset */" or a later "*p = 1;" following a line that merely
	contains "/*" in a string is left alone.

	Takes and returns a list of lines including their line endings.
	"""

	def still_open(text, in_comment):
		"""Return whether a block comment is open once *text* has been scanned."""
		quote = None  # literals cannot span lines, so quote state is per line
		pos = 0
		size = len(text)
		while pos < size:
			ch = text[pos]
			pair = text[pos:pos + 2]
			if in_comment:
				if pair == "*/":
					in_comment = False
					pos += 2
					continue
			elif quote is not None:
				if ch == "\\":
					# Skip the escaped character so \" or \' does not end the literal.
					pos += 2
					continue
				if ch == quote:
					quote = None
			elif pair == "/*":
				in_comment = True
				pos += 2
				continue
			elif pair == "//":
				# Rest of the line is a line comment.
				break
			elif ch in ("'", '"'):
				quote = ch
			pos += 1
		return in_comment

	result = []
	in_comment = False
	for line in lines:
		if in_comment:
			match = re.match(r"^(\s*)\*", line)
			if match:
				indent = match.group(1)
				if not indent.endswith(" "):
					indent += " "
				line = indent + line[match.end(1):]
		in_comment = still_open(line, in_comment)
		result.append(line)
	return result


def fix_brace_newlines(text):
	"""Put the body of one-line enum/struct/union definitions on its own lines.

	Works on the whole file text and returns the rewritten text. Three passes:

	1. "[typedef] enum|struct|union [Name] { body..." is split after the "{".
	   When the remainder ends in "}", "} Name" or "};", the body goes on a
	   line indented one tab deeper and the closing part starts a new line.
	2. A substitution over lines ending in "} [Name];".
	   NOTE(review): its replacement appears to re-emit the matched text
	   unchanged; it is kept as-is so the output does not change.
	3. A line that is exactly "};" (ignoring blanks) is moved to column 0 when
	   the previous line ends with "," or "{".

	The indentation captured in pass 1 is limited to blanks on the same line
	and the body must start with a non-blank character. With a newline-matching
	indent group, blank lines preceding the definition were duplicated in
	front of the moved body, and trailing blanks after "{" produced an empty
	extra line.
	"""

	def split_open_brace(m):
		indent = m.group(1)
		keyword = m.group(2)
		name = m.group(3) or ''
		rest = m.group(4).strip()
		result = f"{indent}{keyword}{name} {{\n"
		if rest:
			# Does the remainder of the line also hold the closing brace?
			trail = re.match(r'^(.+?)\s*(\}\s*\w*\s*;?)$', rest)
			if trail:
				elem = trail.group(1).strip()
				brace = trail.group(2).strip()
				if elem:
					result += f"{indent}\t{elem}\n"
				result += brace
			else:
				result += f"{indent}\t{rest}"
		return result

	text = re.sub(
		r'^([ \t]*)((?:typedef\s+)?(?:enum|struct|union))(\s+\w+)?[ \t]*\{[ \t]*(\S.*)$',
		split_open_brace, text, flags=re.MULTILINE
	)

	def split_close_brace(m):
		content_indent = m.group(1)
		before = m.group(2)
		# Group 3 is the whitespace (possibly a newline) in front of the closing brace.
		brace_line = m.group(3) + m.group(4)
		before_stripped = before.strip()
		if not before_stripped:
			return m.group(0)  # only blanks before "}", keep as-is
		if before_stripped.startswith('}'):
			return m.group(0)  # already starts with "}", keep as-is
		return f"{content_indent}{before_stripped}{brace_line}"

	text = re.sub(
		r'^(\s*)([^{}\n]+?)(\s*)(\}\s*\w*\s*;)$',
		split_close_brace, text, flags=re.MULTILINE
	)

	# Move a standalone "};" to column 0 when it closes an enum/struct-like
	# body, recognised by the previous line ending in "," or "{".
	result = []
	for i, line in enumerate(text.split('\n')):
		if i > 0 and line.strip() == '};':
			prev = result[-1].strip() if result else ''
			if prev.endswith((',', '{')):
				result.append('};')
				continue
		result.append(line)
	return '\n'.join(result)


def is_string_literal_line(content):
	"""Return True when *content*, ignoring leading whitespace, starts with a string literal.

	Recognised openers are a plain double quote and the u8", u", U" and L" prefixes.
	"""
	openers = ('"', "u8\"", "u\"", "U\"", "L\"")
	# An empty or blank line has nothing to match and therefore yields False.
	return content.lstrip().startswith(openers)


def fix_string_continuation_indent(lines):
	"""Indent lines that start with a string literal one tab deeper than their statement.

	A line for which is_string_literal_line() is true is re-indented to
	"<base>\\t", where <base> is the indentation of the closest preceding
	non-empty line. Within a run of consecutive string lines the base of the
	first one is reused, so the whole run lines up. String lines that have no
	preceding non-empty line, or whose preceding non-empty line is a
	preprocessor directive, are not re-indented.

	Takes and returns a list of lines. Every line is rebuilt from its content,
	so a line without a line ending comes back terminated by "\\n".
	"""
	fixed = []
	# Base indentation shared by the current run of string lines (None outside a run).
	run_base_indent = None
	# Content (already re-indented) of the last non-empty line, and whether it was a string line.
	prev_nonempty = ""
	prev_nonempty_is_string = False
	for line in lines:
		content, ending = split_line_ending(line)
		line_sep = ending or "\n"
		stripped = content.lstrip()
		is_string = is_string_literal_line(content)
		if is_string and prev_nonempty and not prev_nonempty.lstrip().startswith('#'):
			if prev_nonempty_is_string and run_base_indent is not None:
				# Continue the run: align with the earlier string lines.
				base_indent = run_base_indent
			else:
				# First string line of a run: base is the indentation of the line above it.
				base_indent = leading_ws(prev_nonempty)
				run_base_indent = base_indent
			content = base_indent + "\t" + stripped
			line = content + line_sep
		else:
			line = content + line_sep
			# A non-empty, non-string line ends the current run; blank lines do not.
			if stripped and not is_string:
				run_base_indent = None
		fixed.append(line)
		if stripped:
			prev_nonempty = content
			prev_nonempty_is_string = is_string
	return fixed


def fix_ternary_lines(lines):
	"""Give the '?' arm of a multi-line ternary its own line.

	When a line holds a '?' with no ':' after it and the next line starts with
	':', the line is split in front of that '?'. The "? ..." part is placed on
	a new line with the same indentation as the ':' line:

		x = cond? a          x = cond
			: b;      ->         ? a
			                     : b;

	Token order is always preserved. The previous implementation tested the
	opposite condition and swapped the '?' and ':' arms, which produced
	invalid C. A line is left untouched when:

	- the '?' is inside a string/character literal, or a ':' follows it;
	- nothing but blanks precedes the '?' (already in the target layout);
	- the line ends with a backslash (macro continuation), since splitting
	  would drop the continuation;
	- the line has no line ending.

	Takes and returns a list of lines including their line endings.
	"""

	def dangling_question(text):
		"""Return the index of the last '?' outside literals that no ':' follows, or -1."""
		quote = None
		escaped = False
		found = -1
		for pos, ch in enumerate(text):
			if escaped:
				escaped = False
			elif quote is not None:
				if ch == "\\":
					escaped = True
				elif ch == quote:
					quote = None
			elif ch in ("'", '"'):
				quote = ch
			elif ch == "?":
				found = pos
			elif ch == ":":
				# This ':' closes (or invalidates) any '?' seen so far.
				found = -1
		return found

	fixed = []
	for index, line in enumerate(lines):
		following = lines[index + 1] if index + 1 < len(lines) else ""
		body = line.rstrip("\r\n")
		ending = line[len(body):]
		split_at = -1
		if (
			ending
			and not body.rstrip().endswith("\\")
			and following.lstrip().startswith(":")
		):
			split_at = dangling_question(body)
		head = body[:split_at].rstrip() if split_at >= 0 else ""
		if not head.strip():
			fixed.append(line)
			continue
		# Reuse the indentation of the ':' line for the new '?' line.
		indent = following[:len(following) - len(following.lstrip())]
		fixed.append(head + ending)
		fixed.append(indent + line[split_at:])
	return fixed


def apply_indent_rules(text):
	"""Apply the radare2 post-processing rules to clang-format output.

	Preprocessor lines go through fix_preprocessor_line() only. Every other
	line runs through the per-line fixers. The multi-line passes and finally
	fix_brace_newlines() then run over the whole file. Returns the new text.
	"""
	per_line = (
		convert_leading_spaces,
		fix_paren_spacing,
		fix_ampersand_paren_spacing,
		fix_ternary_spacing,
		fix_hash_spacing,
		fix_logical_spacing,
	)
	whole_file = (
		fix_case_blocks,
		fix_labels,
		fix_ternary_lines,
		fix_string_continuation_indent,
		fix_multiline_comments,
	)
	processed = []
	for raw in text.splitlines(keepends=True):
		if raw.lstrip().startswith('#'):
			processed.append(fix_preprocessor_line(raw))
		else:
			for step in per_line:
				raw = step(raw)
			processed.append(raw)
	for stage in whole_file:
		processed = stage(processed)
	return fix_brace_newlines("".join(processed))


def format_file(path, clang_format, style_file, check_only=False):
	"""Format one file with clang-format plus the radare2 indentation rules.

	Args:
		path: file to format.
		clang_format: clang-format executable to run.
		style_file: path of the style definition passed as -style=file:<path>.
		check_only: when true the file is not modified; a unified diff is
			printed instead if formatting would change it.

	Returns:
		True when the formatted text differs from the current file content
		(the file was rewritten, or would be in check-only mode), else False.

	Raises:
		FileNotFoundError: *path* is not a regular file.
		RuntimeError: clang-format exited with an error, or the file or the
			clang-format output is not valid UTF-8.
		OSError: reading or replacing the file failed.
	"""
	if not os.path.isfile(path):
		raise FileNotFoundError(f"{path}: no such file")
	try:
		with open(path, "r", encoding="utf-8") as current:
			original = current.read()
	except UnicodeDecodeError as exc:
		# Report undecodable input like any other per-file failure.
		raise RuntimeError(f"{path}: not valid UTF-8 ({exc})") from exc
	try:
		result = subprocess.run(
			[clang_format, "-style=file:" + style_file, path],
			check=True,
			stdout=subprocess.PIPE,
			stderr=subprocess.PIPE,
			# Decode as UTF-8 like the file itself, not with the locale encoding.
			encoding="utf-8",
		)
	except subprocess.CalledProcessError as exc:
		raise RuntimeError(
			f"clang-format failed for {path}: {exc.stderr.strip() or exc}"
		) from exc
	except UnicodeDecodeError as exc:
		raise RuntimeError(
			f"clang-format output for {path} is not valid UTF-8 ({exc})"
		) from exc
	indented = apply_indent_rules(result.stdout)
	if original == indented:
		return False
	if check_only:
		show_diff(path, original, indented)
		return True
	# Write next to the target so os.replace() stays on one filesystem.
	target_dir = os.path.dirname(path) or os.curdir
	tmp = tempfile.NamedTemporaryFile("w", encoding="utf-8", delete=False, dir=target_dir)
	temp_name = tmp.name
	replaced = False
	try:
		with tmp:
			tmp.write(indented)
		# The temporary file is created with private permissions; carry over
		# the mode of the file it is about to replace.
		shutil.copymode(path, temp_name)
		os.replace(temp_name, path)
		replaced = True
	finally:
		if not replaced:
			# Do not leave the temporary file behind when anything failed.
			try:
				os.remove(temp_name)
			except OSError:
				pass
	return True


def main():
	"""Command-line entry point; returns the process exit status.

	Returns 0 on success. Returns 1 when clang-format cannot be found, when
	there is nothing to format, when a file fails to format, or, with
	-n/--no-update, when at least one file needs formatting.
	"""
	args = parse_args()
	if args.print_config:
		print(CLANG_FORMAT_CONFIG)
		return 0
	if args.version:
		print(VERSION)
		return 0

	# Check for clang-format early to fail fast
	clang_format = args.clang_format
	if not shutil.which(clang_format):
		print(
			f"clang-format-radare2: cannot find clang-format ({clang_format})",
			file=sys.stderr,
		)
		return 1

	paths = args.files
	if args.auto:
		paths = load_auto_files()
		if not paths:
			print(
				"clang-format-radare2: auto mode whitelist is empty or missing",
				file=sys.stderr,
			)
			return 1
	if not paths:
		print("clang-format-radare2: no input files", file=sys.stderr)
		return 1
	files = expand_targets(paths)
	if not files:
		print("clang-format-radare2: no input files", file=sys.stderr)
		return 1

	# Use a uniquely named, privately created style file. A fixed name in the
	# shared temporary directory lets concurrent runs delete each other's
	# config and can be pre-created (or symlinked) by another user.
	style_fd, style_file = tempfile.mkstemp(prefix="clang-format-radare2-", suffix=".tmp")
	exit_code = 0
	try:
		with os.fdopen(style_fd, "w", encoding="utf-8") as f:
			f.write(CLANG_FORMAT_CONFIG)
		for path in files:
			try:
				changed = format_file(path, clang_format, style_file, args.no_update)
				if args.no_update and changed:
					exit_code = 1
			except (FileNotFoundError, RuntimeError, OSError, UnicodeDecodeError) as err:
				# One bad file (including a non-UTF-8 one) must not abort the run.
				print(f"clang-format-radare2: {err}", file=sys.stderr)
				exit_code = 1
	finally:
		try:
			os.remove(style_file)
		except OSError:
			pass
	return exit_code


if __name__ == "__main__":
	sys.exit(main())
