from collections import defaultdict
from typing import List


class SolutionBruteForce:
    """O(n^4) reference approach: recompute both XORs for every triplet."""

    def countTriplets(self, arr: List[int]) -> int:
        """Count triplets (i, j, k) with i < j <= k such that
        arr[i] ^ ... ^ arr[j - 1] == arr[j] ^ ... ^ arr[k].
        """
        N = len(arr)
        res = 0
        for i in range(N - 1):
            for j in range(i + 1, N):
                for k in range(j, N):
                    a = b = 0
                    for idx in range(i, j):  # a = XOR of arr[i:j]
                        a ^= arr[idx]
                    for idx in range(j, k + 1):  # b = XOR of arr[j:k+1]
                        b ^= arr[idx]
                    if a == b:
                        res += 1
        return res


class SolutionRunningXor:
    """O(n^3) approach: grow the left and right XORs incrementally."""

    def countTriplets(self, arr: List[int]) -> int:
        """Count triplets (i, j, k) with i < j <= k and equal XOR halves."""
        N = len(arr)
        res = 0
        for i in range(N - 1):
            a = 0  # XOR of arr[i:j], extended by one element per j
            for j in range(i + 1, N):
                a ^= arr[j - 1]
                b = 0  # XOR of arr[j:k+1], extended by one element per k
                for k in range(j, N):
                    b ^= arr[k]
                    if a == b:
                        res += 1
        return res


class SolutionZeroXorSegments:
    """O(n^2) approach based on zero-XOR segments.

    a == b exactly when arr[i] ^ ... ^ arr[k] == 0, and then every split
    point j in i+1..k gives a valid triplet.
    """

    def countTriplets(self, arr: List[int]) -> int:
        """Count triplets (i, j, k) with i < j <= k and equal XOR halves."""
        N = len(arr)
        res = 0
        for i in range(N - 1):
            cur_xor = arr[i]  # XOR of arr[i:k+1]
            for k in range(i + 1, N):
                cur_xor ^= arr[k]
                if cur_xor == 0:
                    # j may be any of i+1, ..., k: that is k - i choices.
                    res += k - i
        return res


class Solution:
    """O(n) prefix-XOR approach (the definition bound to ``Solution``)."""

    def countTriplets(self, arr: List[int]) -> int:
        """Count triplets (i, j, k) with i < j <= k such that
        arr[i] ^ ... ^ arr[j - 1] == arr[j] ^ ... ^ arr[k].

        Let P[0] = 0 and P[t] = arr[0] ^ ... ^ arr[t - 1]. The segment
        arr[p..k] has XOR 0 iff P[p] == P[k + 1], and such a segment yields
        k - p triplets. So for each k we add sum(k - p) over earlier prefix
        positions p holding the same value, i.e. k * count - sum_of_positions.
        """
        res = prefix = 0
        count = defaultdict(int)  # prefix value -> number of positions p seen
        index_sum = defaultdict(int)  # prefix value -> sum of those positions p
        count[0] = 1  # P[0] = 0 at position 0 (its index contributes 0)
        for i, value in enumerate(arr):
            prefix ^= value  # prefix is now P[i + 1]
            # Missing keys read as 0, so no membership check is needed.
            res += i * count[prefix] - index_sum[prefix]
            count[prefix] += 1
            index_sum[prefix] += i + 1  # record position i + 1
        return res