Skip to content
22 changes: 20 additions & 2 deletions crates/vm/src/builtins/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::{
use crate::common::lock::LazyLock;
use crate::object::{Traverse, TraverseFn};
use crate::{
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult,
AsObject, Context, Py, PyExact, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult,
TryFromObject, atomic_func,
builtins::{
PyTuple,
Expand Down Expand Up @@ -681,7 +681,10 @@ impl Py<PyDict> {
let self_exact = self.exact_dict(vm);
let other_exact = other.exact_dict(vm);
if self_exact && other_exact {
self.entries.get_chain(&other.entries, vm, key)
// SAFETY: exact_dict checks passed
let self_exact = unsafe { PyExact::ref_unchecked(self) };
let other_exact = unsafe { PyExact::ref_unchecked(other) };
self_exact.get_chain_exact(other_exact, key, vm)
} else if let Some(value) = self.get_item_opt(key, vm)? {
Ok(Some(value))
} else {
Expand All @@ -690,6 +693,21 @@ impl Py<PyDict> {
}
}

impl PyExact<PyDict> {
/// Look up `key` in `self`, falling back to `other`.
/// Both dicts must be exact `dict` types (enforced by `PyExact`).
pub(crate) fn get_chain_exact<K: DictKey + ?Sized>(
&self,
other: &Self,
key: &K,
vm: &VirtualMachine,
) -> PyResult<Option<PyObjectRef>> {
debug_assert!(self.class().is(vm.ctx.types.dict_type));
debug_assert!(other.class().is(vm.ctx.types.dict_type));
self.entries.get_chain(&other.entries, vm, key)
}
}

// Implement IntoIterator so that we can easily iterate dictionaries from rust code.
impl IntoIterator for PyDictRef {
type Item = (PyObjectRef, PyObjectRef);
Expand Down
13 changes: 13 additions & 0 deletions crates/vm/src/builtins/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,19 @@ pub struct PyRangeIterator {
length: usize,
}

impl PyRangeIterator {
/// Advance and return next value without going through the iterator protocol.
#[inline]
pub(crate) fn next_fast(&self) -> Option<isize> {
let index = self.index.fetch_add(1);
if index < self.length {
Some(self.start + (index as isize) * self.step)
} else {
None
}
}
}

impl PyPayload for PyRangeIterator {
#[inline]
fn class(ctx: &Context) -> &'static Py<PyType> {
Expand Down
14 changes: 14 additions & 0 deletions crates/vm/src/builtins/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1584,6 +1584,20 @@ impl AsMapping for PyStr {
impl AsNumber for PyStr {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
add: Some(|a, b, vm| {
let Some(a) = a.downcast_ref::<PyStr>() else {
return Ok(vm.ctx.not_implemented());
};
let Some(b) = b.downcast_ref::<PyStr>() else {
return Ok(vm.ctx.not_implemented());
};
let bytes = a.as_wtf8().py_add(b.as_wtf8());
Ok(unsafe {
let kind = a.kind() | b.kind();
PyStr::new_str_unchecked(bytes.into(), kind)
}
.to_pyobject(vm))
}),
remainder: Some(|a, b, vm| {
if let Some(a) = a.downcast_ref::<PyStr>() {
a.__mod__(b.to_owned(), vm).to_pyresult(vm)
Expand Down
Loading