Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions bot/exts/filtering/_filter_lists/image_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import typing

import aiohttp
from pydis_core.utils.logging import get_logger

from bot.exts.filtering._filter_context import Event, FilterContext
from bot.exts.filtering._filter_lists.filter_list import FilterList, ListType
from bot.exts.filtering._filters.filter import Filter
from bot.exts.filtering._filters.image_hash import ImageHashFilter
from bot.exts.filtering._image_hash import RhodiumAPIError, get_image_hash
from bot.exts.filtering._settings import ActionSettings

if typing.TYPE_CHECKING:
from bot.exts.filtering.filtering import Filtering

log = get_logger(__name__)

_MAX_IMAGE_SIZE = 5_000_000


class ImageHashesList(FilterList[ImageHashFilter]):
"""A list of perceptual image hashes that should trigger filtering when matched."""

name = "image_hash"

def __init__(self, filtering_cog: Filtering):
super().__init__()
filtering_cog.subscribe(self, Event.MESSAGE)

def get_filter_type(self, content: str) -> type[Filter]:
"""Get a subclass of filter matching the filter list and the filter's content."""
return ImageHashFilter

@property
def filter_types(self) -> set[type[Filter]]:
"""Return the types of filters used by this list."""
return {ImageHashFilter}

async def actions_for(
self, ctx: FilterContext
) -> tuple[ActionSettings | None, list[str], dict[ListType, list[Filter]]]:
"""Dispatch the given event to the list's filters, and return actions to take and messages to relay to mods."""
if not ctx.attachments:
return None, [], {}

image_hashes = []
for attachment in ctx.attachments:
if (
attachment.content_type is None
or not attachment.content_type.startswith("image")
or attachment.size > _MAX_IMAGE_SIZE
):
continue

try:
image_hash = await get_image_hash(attachment.url)
except aiohttp.ClientError:
log.exception("Unhandled aiohttp exception while getting image hash")
continue
except RhodiumAPIError as e:
log.exception("Rhodium API error: %s", e)
continue
except TimeoutError:
log.exception("Timed out getting image hash")
continue

image_hashes.append(image_hash)

if not image_hashes:
return None, [], {}

trigger_ctx = ctx.replace(content=image_hashes)
triggers = await self[ListType.DENY].filter_list_result(trigger_ctx)
if not triggers:
return None, [], {ListType.DENY: triggers}

actions = self[ListType.DENY].merge_actions(triggers)
messages = []
for filter_ in triggers:
distance = ctx.filter_info.get(filter_, "?")
messages.append(
f"{filter_.id} (`{filter_.content}` distance `{distance}`)"
f" - {filter_.description or '*No description*'}"
)
return actions, messages, {ListType.DENY: triggers}
42 changes: 42 additions & 0 deletions bot/exts/filtering/_filters/image_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import re

from discord.ext.commands import BadArgument

from bot.exts.filtering._filter_context import FilterContext
from bot.exts.filtering._filters.filter import Filter
from bot.exts.filtering._image_hash import HASH_DISTANCE_THRESHOLD, signed_i64_to_u64

_HEX_RE = re.compile(r"^(?:0x)?([0-9a-fA-F]{1,16})$")


class ImageHashFilter(Filter):
"""A filter which matches image perceptual hashes represented as hexadecimal values."""

name = "image_hash"

async def triggered_on(self, ctx: FilterContext) -> bool:
"""Search for a perceptual hash match within a given context of attachment hashes."""
candidate_hash = int(self.content, 16)

for image_hash in ctx.content:
normalized_image_hash = signed_i64_to_u64(image_hash)
distance = int.bit_count(normalized_image_hash ^ candidate_hash)
if distance <= HASH_DISTANCE_THRESHOLD:
ctx.matches.append(f"{normalized_image_hash:016x}")
ctx.filter_info[self] = str(distance)
return True
return False

@classmethod
async def process_input(cls, content: str, description: str) -> tuple[str, str]:
"""
Process the content and description into a form which will work with the filtering.

A BadArgument should be raised if the content can't be used.
"""
match = _HEX_RE.fullmatch(content.strip())
if not match:
raise BadArgument("Image hash content must be hexadecimal (optionally prefixed with `0x`).")

normalized = f"{int(match.group(1), 16):016x}"
return normalized, description
104 changes: 0 additions & 104 deletions bot/exts/filtering/_filters/unique/image.py

This file was deleted.

38 changes: 38 additions & 0 deletions bot/exts/filtering/_image_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

from bot import instance
from bot.constants import Keys, URLs

# Maximum number of seconds to wait for Rhodium API.
_TIMEOUT = 5
# Maximum perceptual hash difference for a positive prediction.
HASH_DISTANCE_THRESHOLD = 4


class RhodiumAPIError(Exception):
"""Exception raised when the Rhodium API returns an error."""


async def get_image_hash(image_url: str) -> int:
"""Return the signed i64 perceptual hash for an image URL from Rhodium."""
async with instance.http_session.post(
url=URLs.rhodium_api,
headers={"Authorization": f"Bearer {Keys.rhodium}"},
json={"url": image_url},
timeout=_TIMEOUT,
) as response:
if response.status != 200:
contents = await response.text()
raise RhodiumAPIError(f"Rhodium API returned status code {response.status}: {contents}")

response_data = await response.json()
return response_data["i64"]


def signed_i64_to_hex(value: int) -> str:
"""Convert a signed 64-bit integer to a normalized lowercase 16-char hexadecimal string."""
return f"{value & ((1 << 64) - 1):016x}"


def signed_i64_to_u64(value: int) -> int:
"""Convert a signed 64-bit integer into its unsigned 64-bit representation."""
return value & ((1 << 64) - 1)
Loading